Skip to content

Commit

Permalink
feat: enable OnOpen for connected UDP socket (#554)
Browse files Browse the repository at this point in the history
Fixes #549
  • Loading branch information
panjf2000 authored Mar 23, 2024
1 parent 54f81b6 commit ab69eec
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 34 deletions.
2 changes: 1 addition & 1 deletion acceptor_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (eng *engine) listen() (err error) {
}
el := eng.eventLoops.next(tc.RemoteAddr())
c := newTCPConn(tc, el)
el.ch <- c
el.ch <- &openConn{c: c}
go func(c *conn, tc net.Conn, el *eventloop) {
var buffer [0x10000]byte
for {
Expand Down
82 changes: 63 additions & 19 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
package gnet

import (
"bytes"
"io"
"math/rand"
"net"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -41,6 +43,13 @@ func (ev *clientEvents) OnBoot(e Engine) Action {
return None
}

var pingMsg = []byte("PING\r\n")

func (ev *clientEvents) OnOpen(Conn) (out []byte, action Action) {
out = pingMsg
return
}

func (ev *clientEvents) OnClose(Conn, error) Action {
if ev.svr != nil {
if atomic.AddInt32(&ev.svr.clientActive, -1) == 0 {
Expand All @@ -53,7 +62,7 @@ func (ev *clientEvents) OnClose(Conn, error) Action {
func (ev *clientEvents) OnTraffic(c Conn) (action Action) {
handler := c.Context().(*connHandler)
if handler.network == "udp" {
ev.packetLen = 1024
ev.packetLen = datagramLen
}
buf, err := c.Next(-1)
assert.NoError(ev.tester, err)
Expand Down Expand Up @@ -190,19 +199,20 @@ func TestServeWithGnetClient(t *testing.T) {

type testClientServer struct {
*BuiltinEventEngine
client *Client
tester *testing.T
eng Engine
network string
addr string
multicore bool
async bool
nclients int
started int32
connected int32
clientActive int32
disconnected int32
workerPool *goPool.Pool
client *Client
tester *testing.T
eng Engine
network string
addr string
multicore bool
async bool
nclients int
started int32
connected int32
clientActive int32
disconnected int32
workerPool *goPool.Pool
udpReadHeader int32
}

func (s *testClientServer) OnBoot(eng Engine) (action Action) {
Expand All @@ -211,7 +221,7 @@ func (s *testClientServer) OnBoot(eng Engine) (action Action) {
}

func (s *testClientServer) OnOpen(c Conn) (out []byte, action Action) {
c.SetContext(c)
c.SetContext(&sync.Once{})
atomic.AddInt32(&s.connected, 1)
require.NotNil(s.tester, c.LocalAddr(), "nil local addr")
require.NotNil(s.tester, c.RemoteAddr(), "nil remote addr")
Expand All @@ -223,7 +233,7 @@ func (s *testClientServer) OnClose(c Conn, err error) (action Action) {
logging.Debugf("error occurred on closed, %v\n", err)
}
if s.network != "udp" {
require.Equal(s.tester, c.Context(), c, "invalid context")
require.IsType(s.tester, c.Context(), new(sync.Once), "invalid context")
}

atomic.AddInt32(&s.disconnected, 1)
Expand All @@ -236,7 +246,25 @@ func (s *testClientServer) OnClose(c Conn, err error) (action Action) {
return
}

func (s *testClientServer) OnShutdown(Engine) {
if s.network == "udp" {
require.EqualValues(s.tester, int32(s.nclients), atomic.LoadInt32(&s.udpReadHeader))
}
}

func (s *testClientServer) OnTraffic(c Conn) (action Action) {
readHeader := func() {
ping := make([]byte, len(pingMsg))
n, err := io.ReadFull(c, ping)
require.NoError(s.tester, err)
require.EqualValues(s.tester, len(pingMsg), n)
require.Equal(s.tester, string(pingMsg), string(ping), "bad header")
}
v := c.Context()
if v != nil {
v.(*sync.Once).Do(readHeader)
}

if s.async {
buf := bbPool.Get()
_, _ = c.WriteTo(buf)
Expand All @@ -247,14 +275,30 @@ func (s *testClientServer) OnTraffic(c Conn) (action Action) {
_ = c.OutboundBuffered()
_, _ = c.Discard(1)
}
if v == nil && bytes.Equal(buf.Bytes(), pingMsg) {
atomic.AddInt32(&s.udpReadHeader, 1)
buf.Reset()
}
_ = s.workerPool.Submit(
func() {
_ = c.AsyncWrite(buf.Bytes(), nil)
if buf.Len() > 0 {
err := c.AsyncWrite(buf.Bytes(), nil)
require.NoError(s.tester, err)
}
})
return
}

buf, _ := c.Next(-1)
_, _ = c.Write(buf)
if v == nil && bytes.Equal(buf, pingMsg) {
atomic.AddInt32(&s.udpReadHeader, 1)
buf = nil
}
if len(buf) > 0 {
n, err := c.Write(buf)
require.NoError(s.tester, err)
require.EqualValues(s.tester, len(buf), n)
}
return
}

Expand Down Expand Up @@ -343,7 +387,7 @@ func startGnetClient(t *testing.T, cli *Client, network, addr string, multicore,
for time.Since(start) < duration {
reqData := make([]byte, streamLen)
if network == "udp" {
reqData = reqData[:1024]
reqData = reqData[:datagramLen]
}
_, err = rand.Read(reqData)
require.NoError(t, err)
Expand Down
13 changes: 10 additions & 3 deletions client_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (cli *Client) EnrollContext(c net.Conn, ctx interface{}) (Conn, error) {

var (
sockAddr unix.Sockaddr
gc Conn
gc *conn
)
switch c.(type) {
case *net.UnixConn:
Expand Down Expand Up @@ -227,11 +227,18 @@ func (cli *Client) EnrollContext(c net.Conn, ctx interface{}) (Conn, error) {
default:
return nil, errorx.ErrUnsupportedProtocol
}
gc.SetContext(ctx)
err = cli.el.poller.UrgentTrigger(cli.el.register, gc)
gc.ctx = ctx

connOpened := make(chan struct{})
ccb := &connWithCallback{c: gc, cb: func() {
close(connOpened)
}}
err = cli.el.poller.UrgentTrigger(cli.el.register, ccb)
if err != nil {
gc.Close()
return nil, err
}

<-connOpened
return gc, nil
}
7 changes: 5 additions & 2 deletions client_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) {
}

func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err error) {
connOpened := make(chan struct{})
switch v := nc.(type) {
case *net.TCPConn:
if cli.opts.TCPNoDelay == TCPNoDelay {
Expand All @@ -165,7 +166,7 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err

c := newTCPConn(nc, cli.el)
c.SetContext(ctx)
cli.el.ch <- c
cli.el.ch <- &openConn{c: c, cb: func() { close(connOpened) }}
go func(c *conn, tc net.Conn, el *eventloop) {
var buffer [0x10000]byte
for {
Expand All @@ -181,7 +182,7 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err
case *net.UnixConn:
c := newTCPConn(nc, cli.el)
c.SetContext(ctx)
cli.el.ch <- c
cli.el.ch <- &openConn{c: c, cb: func() { close(connOpened) }}
go func(c *conn, uc net.Conn, el *eventloop) {
var buffer [0x10000]byte
for {
Expand All @@ -204,6 +205,7 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err
c := newUDPConn(cli.el, nc.LocalAddr(), nc.RemoteAddr())
c.SetContext(ctx)
c.rawConn = nc
cli.el.ch <- &openConn{c: c, isDatagram: true, cb: func() { close(connOpened) }}
go func(uc net.Conn, el *eventloop) {
var buffer [0x10000]byte
for {
Expand All @@ -222,5 +224,6 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err
return nil, errorx.ErrUnsupportedProtocol
}

<-connOpened
return
}
6 changes: 5 additions & 1 deletion connection_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, con
}

func (c *conn) release() {
c.opened = false
c.ctx = nil
c.buffer = nil
if addr, ok := c.localAddr.(*net.TCPAddr); ok && c.localAddr != c.loop.ln.addr && len(addr.Zone) > 0 {
Expand All @@ -102,14 +103,17 @@ func (c *conn) release() {
c.remoteAddr = nil
c.pollAttachment.FD, c.pollAttachment.Callback = 0, nil
if !c.isDatagram {
c.opened = false
c.peer = nil
c.inboundBuffer.Done()
c.outboundBuffer.Release()
}
}

func (c *conn) open(buf []byte) error {
if c.isDatagram && c.peer == nil {
return unix.Send(c.fd, buf, 0)
}

n, err := unix.Write(c.fd, buf)
if err != nil && err == unix.EAGAIN {
_, _ = c.outboundBuffer.Write(buf)
Expand Down
6 changes: 6 additions & 0 deletions connection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ type udpConn struct {
c *conn
}

type openConn struct {
c *conn
cb func()
isDatagram bool
}

type conn struct {
ctx interface{} // user-defined context
loop *eventloop // owner event-loop
Expand Down
15 changes: 13 additions & 2 deletions eventloop_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,19 @@ func (el *eventloop) closeConns() {
})
}

type connWithCallback struct {
c *conn
cb func()
}

func (el *eventloop) register(itf interface{}) error {
c := itf.(*conn)
c, ok := itf.(*conn)
if !ok {
ccb := itf.(*connWithCallback)
c = ccb.c
defer ccb.cb()
}

if err := el.poller.AddRead(&c.pollAttachment); err != nil {
_ = unix.Close(c.fd)
c.release()
Expand All @@ -71,7 +82,7 @@ func (el *eventloop) register(itf interface{}) error {

el.connections.addConn(c, el.idx)

if c.isDatagram {
if c.isDatagram && c.peer != nil {
return nil
}
return el.open(c)
Expand Down
15 changes: 11 additions & 4 deletions eventloop_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (el *eventloop) run() (err error) {
err = v
case *netErr:
err = el.close(v.c, v.err)
case *conn:
case *openConn:
err = el.open(v)
case *tcpConn:
unpackTCPConn(v)
Expand All @@ -90,9 +90,16 @@ func (el *eventloop) run() (err error) {
return nil
}

func (el *eventloop) open(c *conn) error {
el.connections[c] = struct{}{}
el.incConn(1)
func (el *eventloop) open(oc *openConn) error {
if oc.cb != nil {
defer oc.cb()
}

c := oc.c
if !oc.isDatagram {
el.connections[c] = struct{}{}
el.incConn(1)
}

out, action := el.eventHandler.OnOpen(c)
if out != nil {
Expand Down
7 changes: 5 additions & 2 deletions gnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ import (
goPool "github.com/panjf2000/gnet/v2/pkg/pool/goroutine"
)

var streamLen = 1024 * 1024
var (
datagramLen = 1024
streamLen = 1024 * 1024
)

func TestServe(t *testing.T) {
// start an engine
Expand Down Expand Up @@ -415,7 +418,7 @@ func startClient(t *testing.T, network, addr string, multicore, async bool) {
for time.Since(start) < duration {
reqData := make([]byte, streamLen)
if network == "udp" {
reqData = reqData[:1024]
reqData = reqData[:datagramLen]
}
_, err = rand.Read(reqData)
require.NoError(t, err)
Expand Down

0 comments on commit ab69eec

Please sign in to comment.