diff --git a/client_test.go b/client_test.go index 47a454a99..e51b4fcd0 100644 --- a/client_test.go +++ b/client_test.go @@ -591,13 +591,13 @@ func (ev *clientEventsForWake) OnTraffic(c Conn) (action Action) { assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err) buf, err = c.Next(-1) - assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) + assert.Emptyf(ev.tester, buf, "expected an empty slice, but got: %v", buf) assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err) buf, err = c.Peek(10) assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err) buf, err = c.Peek(-1) - assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) + assert.Emptyf(ev.tester, buf, "expected an empty slice, but got: %v", buf) assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err) n, err = c.Discard(10) assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n) diff --git a/client_unix.go b/client_unix.go index 709a09ef2..7fef073b9 100644 --- a/client_unix.go +++ b/client_unix.go @@ -135,7 +135,7 @@ func (cli *Client) Start() error { // Stop stops the client event-loop. func (cli *Client) Stop() (err error) { - logging.Error(cli.el.poller.Trigger(queue.HighPriority, func(_ interface{}) error { return errorx.ErrEngineShutdown }, nil)) + logging.Error(cli.el.poller.Trigger(queue.HighPriority, func(_ any) error { return errorx.ErrEngineShutdown }, nil)) // Stop the ticker. if cli.opts.Ticker { cli.el.engine.ticker.cancel() @@ -153,7 +153,7 @@ func (cli *Client) Dial(network, address string) (Conn, error) { } // DialContext is like Dial but also accepts an empty interface ctx that can be obtained later via Conn.Context. -func (cli *Client) DialContext(network, address string, ctx interface{}) (Conn, error) { +func (cli *Client) DialContext(network, address string, ctx any) (Conn, error) { c, err := net.Dial(network, address) if err != nil { return nil, err @@ -167,7 +167,7 @@ func (cli *Client) Enroll(c net.Conn) (Conn, error) { } // EnrollContext is like Enroll but also accepts an empty interface ctx that can be obtained later via Conn.Context. -func (cli *Client) EnrollContext(c net.Conn, ctx interface{}) (Conn, error) { +func (cli *Client) EnrollContext(c net.Conn, ctx any) (Conn, error) { defer c.Close() sc, ok := c.(syscall.Conn) diff --git a/client_windows.go b/client_windows.go index d9cbbde4b..ae2750bb7 100644 --- a/client_windows.go +++ b/client_windows.go @@ -61,7 +61,7 @@ func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { eventHandler: eh, } cli.el = &eventloop{ - ch: make(chan interface{}, 1024), + ch: make(chan any, 1024), eng: eng, connections: make(map[*conn]struct{}), eventHandler: eh, @@ -121,7 +121,7 @@ func (cli *Client) Dial(network, addr string) (Conn, error) { return cli.DialContext(network, addr, nil) } -func (cli *Client) DialContext(network, addr string, ctx interface{}) (Conn, error) { +func (cli *Client) DialContext(network, addr string, ctx any) (Conn, error) { var ( c net.Conn err error @@ -146,7 +146,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) { return cli.EnrollContext(nc, nil) } -func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err error) { +func (cli *Client) EnrollContext(nc net.Conn, ctx any) (gc Conn, err error) { connOpened := make(chan struct{}) switch v := nc.(type) { case *net.TCPConn: diff --git a/connection_unix.go b/connection_unix.go index 3ccac4ab1..dc568ad80 100644 --- a/connection_unix.go +++ b/connection_unix.go @@ -40,7 +40,7 @@ import ( type conn struct { fd int // file descriptor gfd gfd.GFD // gnet file descriptor - ctx interface{} // user-defined context + ctx any // user-defined context remote unix.Sockaddr // remote socket address localAddr net.Addr // local addr remoteAddr net.Addr // remote addr @@ -243,8 +243,8 @@ type asyncWriteHook struct { data []byte } -func (c *conn) asyncWrite(itf interface{}) (err error) { - hook := itf.(*asyncWriteHook) +func (c *conn) asyncWrite(a any) (err error) { + hook := a.(*asyncWriteHook) defer func() { if hook.callback != nil { _ = hook.callback(c, err) @@ -264,8 +264,8 @@ type asyncWritevHook struct { data [][]byte } -func (c *conn) asyncWritev(itf interface{}) (err error) { - hook := itf.(*asyncWritevHook) +func (c *conn) asyncWritev(a any) (err error) { + hook := a.(*asyncWritevHook) defer func() { if hook.callback != nil { _ = hook.callback(c, err) @@ -318,16 +318,18 @@ func (c *conn) Next(n int) (buf []byte, err error) { } else if n <= 0 { n = totalLen } + if c.inboundBuffer.IsEmpty() { buf = c.buffer[:n] c.buffer = c.buffer[n:] return } + head, tail := c.inboundBuffer.Peek(n) defer c.inboundBuffer.Discard(n) //nolint:errcheck c.loop.cache.Reset() c.loop.cache.Write(head) - if len(head) >= n { + if len(head) == n { return c.loop.cache.Bytes(), err } c.loop.cache.Write(tail) @@ -348,12 +350,14 @@ func (c *conn) Peek(n int) (buf []byte, err error) { } else if n <= 0 { n = totalLen } + if c.inboundBuffer.IsEmpty() { return c.buffer[:n], err } + head, tail := c.inboundBuffer.Peek(n) - if len(head) >= n { - return head[:n], err + if len(head) == n { + return head, err } c.loop.cache.Reset() c.loop.cache.Write(head) @@ -435,10 +439,10 @@ func (c *conn) OutboundBuffered() int { return c.outboundBuffer.Buffered() } -func (c *conn) Context() interface{} { return c.ctx } -func (c *conn) SetContext(ctx interface{}) { c.ctx = ctx } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) Context() any { return c.ctx } +func (c *conn) SetContext(ctx any) { c.ctx = ctx } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } // Implementation of Socket interface @@ -485,7 +489,7 @@ func (c *conn) AsyncWritev(bs [][]byte, callback AsyncCallback) error { } func (c *conn) Wake(callback AsyncCallback) error { - return c.loop.poller.Trigger(queue.LowPriority, func(_ interface{}) (err error) { + return c.loop.poller.Trigger(queue.LowPriority, func(_ any) (err error) { err = c.loop.wake(c) if callback != nil { _ = callback(c, err) @@ -495,7 +499,7 @@ func (c *conn) Wake(callback AsyncCallback) error { } func (c *conn) CloseWithCallback(callback AsyncCallback) error { - return c.loop.poller.Trigger(queue.LowPriority, func(_ interface{}) (err error) { + return c.loop.poller.Trigger(queue.LowPriority, func(_ any) (err error) { err = c.loop.close(c, nil) if callback != nil { _ = callback(c, err) @@ -505,7 +509,7 @@ func (c *conn) CloseWithCallback(callback AsyncCallback) error { } func (c *conn) Close() error { - return c.loop.poller.Trigger(queue.LowPriority, func(_ interface{}) (err error) { + return c.loop.poller.Trigger(queue.LowPriority, func(_ any) (err error) { err = c.loop.close(c, nil) return }, nil) diff --git a/connection_windows.go b/connection_windows.go index 745028434..443404c4a 100644 --- a/connection_windows.go +++ b/connection_windows.go @@ -21,11 +21,13 @@ import ( "syscall" "time" + "github.com/panjf2000/ants/v2" "golang.org/x/sys/windows" "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" errorx "github.com/panjf2000/gnet/v2/pkg/errors" bbPool "github.com/panjf2000/gnet/v2/pkg/pool/bytebuffer" + goPool "github.com/panjf2000/gnet/v2/pkg/pool/goroutine" ) type netErr struct { @@ -34,8 +36,8 @@ type netErr struct { } type tcpConn struct { - c *conn - buf *bbPool.ByteBuffer + c *conn + b *bbPool.ByteBuffer } type udpConn struct { @@ -49,7 +51,7 @@ type openConn struct { type conn struct { pc net.PacketConn - ctx interface{} // user-defined context + ctx any // user-defined context loop *eventloop // owner event-loop buffer *bbPool.ByteBuffer // reuse memory of inbound data as a temporary buffer rawConn net.Conn // original connection @@ -59,35 +61,34 @@ type conn struct { } func packTCPConn(c *conn, buf []byte) *tcpConn { - tc := &tcpConn{c: c, buf: bbPool.Get()} - _, _ = tc.buf.Write(buf) - return tc + b := bbPool.Get() + _, _ = b.Write(buf) + return &tcpConn{c: c, b: b} } -func unpackTCPConn(tc *tcpConn) { - tc.c.buffer = tc.buf - tc.buf = nil -} - -func resetTCPConn(tc *tcpConn) { - bbPool.Put(tc.c.buffer) - tc.c.buffer = nil +func unpackTCPConn(tc *tcpConn) *conn { + if tc.c.buffer == nil { // the connection has been closed + return nil + } + _, _ = tc.c.buffer.Write(tc.b.B) + bbPool.Put(tc.b) + tc.b = nil + return tc.c } func packUDPConn(c *conn, buf []byte) *udpConn { - uc := &udpConn{c} - _, _ = uc.c.buffer.Write(buf) - return uc + _, _ = c.buffer.Write(buf) + return &udpConn{c} } func newTCPConn(nc net.Conn, el *eventloop) (c *conn) { - c = &conn{ - loop: el, - rawConn: nc, + return &conn{ + loop: el, + buffer: bbPool.Get(), + rawConn: nc, + localAddr: nc.LocalAddr(), + remoteAddr: nc.RemoteAddr(), } - c.localAddr = c.rawConn.LocalAddr() - c.remoteAddr = c.rawConn.RemoteAddr() - return } func (c *conn) release() { @@ -118,18 +119,11 @@ func (c *conn) resetBuffer() { } func (c *conn) Read(p []byte) (n int, err error) { - if c.buffer == nil { - if len(p) == 0 { - return 0, nil - } - return 0, io.ErrShortBuffer - } - if c.inboundBuffer.IsEmpty() { n = copy(p, c.buffer.B) c.buffer.B = c.buffer.B[n:] if n == 0 && len(p) > 0 { - err = io.EOF + err = io.ErrShortBuffer } return } @@ -144,13 +138,6 @@ func (c *conn) Read(p []byte) (n int, err error) { } func (c *conn) Next(n int) (buf []byte, err error) { - if c.buffer == nil { - if n <= 0 { - return nil, nil - } - return nil, io.ErrShortBuffer - } - inBufferLen := c.inboundBuffer.Buffered() if totalLen := inBufferLen + c.buffer.Len(); n > totalLen { return nil, io.ErrShortBuffer @@ -166,7 +153,7 @@ func (c *conn) Next(n int) (buf []byte, err error) { defer c.inboundBuffer.Discard(n) //nolint:errcheck c.loop.cache.Reset() c.loop.cache.Write(head) - if len(head) >= n { + if len(head) == n { return c.loop.cache.Bytes(), err } c.loop.cache.Write(tail) @@ -181,13 +168,6 @@ func (c *conn) Next(n int) (buf []byte, err error) { } func (c *conn) Peek(n int) (buf []byte, err error) { - if c.buffer == nil { - if n <= 0 { - return nil, nil - } - return nil, io.ErrShortBuffer - } - inBufferLen := c.inboundBuffer.Buffered() if totalLen := inBufferLen + c.buffer.Len(); n > totalLen { return nil, io.ErrShortBuffer @@ -198,8 +178,8 @@ func (c *conn) Peek(n int) (buf []byte, err error) { return c.buffer.B[:n], err } head, tail := c.inboundBuffer.Peek(n) - if len(head) >= n { - return head[:n], err + if len(head) == n { + return head, err } c.loop.cache.Reset() c.loop.cache.Write(head) @@ -214,10 +194,6 @@ func (c *conn) Peek(n int) (buf []byte, err error) { } func (c *conn) Discard(n int) (int, error) { - if c.buffer == nil { - return 0, nil - } - inBufferLen := c.inboundBuffer.Buffered() tempBufferLen := c.buffer.Len() if inBufferLen+tempBufferLen < n || n <= 0 { @@ -297,10 +273,10 @@ func (c *conn) OutboundBuffered() int { return 0 } -func (c *conn) Context() interface{} { return c.ctx } -func (c *conn) SetContext(ctx interface{}) { c.ctx = ctx } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) Context() any { return c.ctx } +func (c *conn) SetContext(ctx any) { c.ctx = ctx } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } func (c *conn) Fd() (fd int) { if c.rawConn == nil { @@ -430,19 +406,43 @@ func (c *conn) SetKeepAlivePeriod(d time.Duration) error { return nil } +type nonBlockingPool struct { + *goPool.Pool +} + +func (np *nonBlockingPool) Go(task func()) (err error) { + if err = np.Submit(task); err == ants.ErrPoolOverload { + go task() + } + return +} + +var workerPool = nonBlockingPool{Pool: goPool.Default()} + // Gfd return an uninitialized GFD which is not valid, // this method is only implemented for compatibility, don't use it on Windows. // func (c *conn) Gfd() gfd.GFD { return gfd.GFD{} } func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error { - if cb == nil { - cb = func(c Conn, err error) error { return nil } - } _, err := c.Write(buf) - c.loop.ch <- func() error { - return cb(c, err) + + callback := func() error { + if cb != nil { + _ = cb(c, err) + } + return err } - return nil + + select { + case c.loop.ch <- callback: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + err = workerPool.Go(func() { + c.loop.ch <- callback + }) + } + + return err } func (c *conn) AsyncWritev(bs [][]byte, cb AsyncCallback) error { @@ -459,48 +459,63 @@ func (c *conn) AsyncWritev(bs [][]byte, cb AsyncCallback) error { }) } -func (c *conn) Wake(cb AsyncCallback) error { - if cb == nil { - cb = func(c Conn, err error) error { return nil } +func (c *conn) Wake(cb AsyncCallback) (err error) { + wakeFn := func() (err error) { + err = c.loop.wake(c) + if cb != nil { + _ = cb(c, err) + } + return } - c.loop.ch <- func() (err error) { - defer func() { - defer func() { - if err == nil { - err = cb(c, nil) - return - } - _ = cb(c, err) - }() - }() - return c.loop.wake(c) + + select { + case c.loop.ch <- wakeFn: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + err = workerPool.Go(func() { + c.loop.ch <- wakeFn + }) } - return nil + + return } -func (c *conn) Close() error { - c.loop.ch <- func() error { - err := c.loop.close(c, nil) - return err +func (c *conn) Close() (err error) { + closeFn := func() error { + return c.loop.close(c, nil) } - return nil -} -func (c *conn) CloseWithCallback(cb AsyncCallback) error { - if cb == nil { - cb = func(c Conn, err error) error { return nil } + select { + case c.loop.ch <- closeFn: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + err = workerPool.Go(func() { + c.loop.ch <- closeFn + }) } - c.loop.ch <- func() (err error) { - defer func() { - if err == nil { - err = cb(c, nil) - return - } + + return +} + +func (c *conn) CloseWithCallback(cb AsyncCallback) (err error) { + closeFn := func() (err error) { + err = c.loop.close(c, nil) + if cb != nil { _ = cb(c, err) - }() - return c.loop.close(c, nil) + } + return } - return nil + + select { + case c.loop.ch <- closeFn: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + err = workerPool.Go(func() { + c.loop.ch <- closeFn + }) + } + + return } func (*conn) SetDeadline(_ time.Time) error { diff --git a/engine_unix.go b/engine_unix.go index 3bc488f95..607e2f01f 100644 --- a/engine_unix.go +++ b/engine_unix.go @@ -212,14 +212,14 @@ func (eng *engine) stop(s Engine) { // Notify all event-loops to exit. eng.eventLoops.iterate(func(i int, el *eventloop) bool { - err := el.poller.Trigger(queue.HighPriority, func(_ interface{}) error { return errorx.ErrEngineShutdown }, nil) + err := el.poller.Trigger(queue.HighPriority, func(_ any) error { return errorx.ErrEngineShutdown }, nil) if err != nil { eng.opts.Logger.Errorf("failed to enqueue shutdown signal of high-priority for event-loop(%d): %v", i, err) } return true }) if eng.ingress != nil { - err := eng.ingress.poller.Trigger(queue.HighPriority, func(_ interface{}) error { return errorx.ErrEngineShutdown }, nil) + err := eng.ingress.poller.Trigger(queue.HighPriority, func(_ any) error { return errorx.ErrEngineShutdown }, nil) if err != nil { eng.opts.Logger.Errorf("failed to enqueue shutdown signal of high-priority for main event-loop: %v", err) } diff --git a/engine_windows.go b/engine_windows.go index 304b14421..3eabb03ba 100644 --- a/engine_windows.go +++ b/engine_windows.go @@ -74,7 +74,7 @@ func (eng *engine) closeEventLoops() { func (eng *engine) start(numEventLoop int) error { for i := 0; i < numEventLoop; i++ { el := eventloop{ - ch: make(chan interface{}, 1024), + ch: make(chan any, 1024), idx: i, eng: eng, connections: make(map[*conn]struct{}), diff --git a/eventloop_unix.go b/eventloop_unix.go index 6602e78eb..ff73ef18b 100644 --- a/eventloop_unix.go +++ b/eventloop_unix.go @@ -68,10 +68,10 @@ type connWithCallback struct { cb func() } -func (el *eventloop) register(itf interface{}) error { - c, ok := itf.(*conn) +func (el *eventloop) register(a any) error { + c, ok := a.(*conn) if !ok { - ccb := itf.(*connWithCallback) + ccb := a.(*connWithCallback) c = ccb.c defer ccb.cb() } @@ -114,8 +114,8 @@ func (el *eventloop) open(c *conn) error { return el.handleAction(c, action) } -func (el *eventloop) read0(itf interface{}) error { - return el.read(itf.(*conn)) +func (el *eventloop) read0(a any) error { + return el.read(a.(*conn)) } func (el *eventloop) read(c *conn) error { @@ -166,8 +166,8 @@ loop: return nil } -func (el *eventloop) write0(itf interface{}) error { - return el.write(itf.(*conn)) +func (el *eventloop) write0(a any) error { + return el.write(a.(*conn)) } // The default value of UIO_MAXIOV/IOV_MAX is 1024 on Linux and most BSD-like OSs. @@ -297,7 +297,7 @@ func (el *eventloop) ticker(ctx context.Context) { case Shutdown: // It seems reasonable to mark this as low-priority, waiting for some tasks like asynchronous writes // to finish up before shutting down the service. - err := el.poller.Trigger(queue.LowPriority, func(_ interface{}) error { return errorx.ErrEngineShutdown }, nil) + err := el.poller.Trigger(queue.LowPriority, func(_ any) error { return errorx.ErrEngineShutdown }, nil) el.getLogger().Debugf("failed to enqueue shutdown signal of high-priority for event-loop(%d): %v", el.idx, err) } if timer == nil { @@ -354,8 +354,8 @@ func (el *eventloop) handleAction(c *conn, action Action) error { } /* -func (el *eventloop) execCmd(itf interface{}) (err error) { - cmd := itf.(*asyncCmd) +func (el *eventloop) execCmd(a any) (err error) { + cmd := a.(*asyncCmd) c := el.connections.getConnByGFD(cmd.fd) if c == nil || c.gfd != cmd.fd { return errorx.ErrInvalidConn @@ -373,9 +373,9 @@ func (el *eventloop) execCmd(itf interface{}) (err error) { case asyncCmdWake: return el.wake(c) case asyncCmdWrite: - _, err = c.Write(cmd.arg.([]byte)) + _, err = c.Write(cmd.param.([]byte)) case asyncCmdWritev: - _, err = c.Writev(cmd.arg.([][]byte)) + _, err = c.Writev(cmd.param.([][]byte)) default: return errorx.ErrUnsupportedOp } diff --git a/eventloop_windows.go b/eventloop_windows.go index 906d8924f..460652981 100644 --- a/eventloop_windows.go +++ b/eventloop_windows.go @@ -28,7 +28,7 @@ import ( ) type eventloop struct { - ch chan interface{} // channel for event-loop + ch chan any // channel for event-loop idx int // index of event-loop in event-loops eng *engine // engine in loop cache bytes.Buffer // temporary buffer for scattered bytes @@ -71,9 +71,7 @@ func (el *eventloop) run() (err error) { case *openConn: err = el.open(v) case *tcpConn: - unpackTCPConn(v) - err = el.read(v.c) - resetTCPConn(v) + err = el.read(unpackTCPConn(v)) case *udpConn: err = el.readUDP(v.c) case func() error: diff --git a/gnet.go b/gnet.go index f05197e72..b2572932f 100644 --- a/gnet.go +++ b/gnet.go @@ -126,7 +126,7 @@ type asyncCmd struct { fd gfd.GFD typ asyncCmdType cb AsyncCallback - arg interface{} + param any } // AsyncWrite writes data to the given connection asynchronously. @@ -135,7 +135,7 @@ func (e Engine) AsyncWrite(fd gfd.GFD, p []byte, cb AsyncCallback) error { return err } - return e.eng.sendCmd(&asyncCmd{fd: fd, typ: asyncCmdWrite, cb: cb, arg: p}, false) + return e.eng.sendCmd(&asyncCmd{fd: fd, typ: asyncCmdWrite, cb: cb, param: p}, false) } // AsyncWritev is like AsyncWrite, but it accepts a slice of byte slices. @@ -144,7 +144,7 @@ func (e Engine) AsyncWritev(fd gfd.GFD, batch [][]byte, cb AsyncCallback) error return err } - return e.eng.sendCmd(&asyncCmd{fd: fd, typ: asyncCmdWritev, cb: cb, arg: batch}, false) + return e.eng.sendCmd(&asyncCmd{fd: fd, typ: asyncCmdWritev, cb: cb, param: batch}, false) } // Close closes the given connection. @@ -237,9 +237,12 @@ type Writer interface { AsyncWritev(bs [][]byte, callback AsyncCallback) (err error) } -// AsyncCallback is a callback which will be invoked after the asynchronous functions has finished executing. +// AsyncCallback is a callback that will be invoked after the asynchronous function finishes. // -// Note that the parameter gnet.Conn is already released under UDP protocol, thus it's not allowed to be accessed. +// Note that the parameter gnet.Conn might have been already released when it's UDP protocol, +// thus it shouldn't be accessed. +// This callback will be executed in event-loop, thus it must not block, otherwise, +// it blocks the event-loop. type AsyncCallback func(c Conn, err error) error // Socket is a set of functions which manipulate the underlying file descriptor of a connection. @@ -303,11 +306,11 @@ type Conn interface { // Context returns a user-defined context, it's not concurrency-safe, // you must invoke it within any method in EventHandler. - Context() (ctx interface{}) + Context() (ctx any) // SetContext sets a user-defined context, it's not concurrency-safe, // you must invoke it within any method in EventHandler. - SetContext(ctx interface{}) + SetContext(ctx any) // LocalAddr is the connection's local socket address, it's not concurrency-safe, // you must invoke it within any method in EventHandler. diff --git a/gnet_test.go b/gnet_test.go index 201c6b9de..b6755e038 100644 --- a/gnet_test.go +++ b/gnet_test.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "errors" "io" + "math" "math/rand" "net" "path/filepath" @@ -1542,7 +1543,8 @@ type simServer struct { multicore bool nclients int packetSize int - packetBatch int + batchWrite int + batchRead int started int32 connected int32 disconnected int32 @@ -1579,7 +1581,7 @@ func (s *simServer) OnClose(_ Conn, err error) (action Action) { func (s *simServer) OnTraffic(c Conn) (action Action) { codec := c.Context().(*testCodec) var packets [][]byte - for { + for i := 0; i < s.batchRead; i++ { data, err := codec.Decode(c) if errors.Is(err, errIncompletePacket) { break @@ -1596,6 +1598,10 @@ func (s *simServer) OnTraffic(c Conn) (action Action) { } else if n == 1 { _, _ = c.Write(packets[0]) } + if len(packets) == s.batchRead && c.InboundBuffered() > 0 { + err := c.Wake(nil) // wake up the connection manually to avoid missing the leftover data + assert.NoError(s.tester, err) + } return } @@ -1603,7 +1609,7 @@ func (s *simServer) OnTick() (delay time.Duration, action Action) { if atomic.CompareAndSwapInt32(&s.started, 0, 1) { for i := 0; i < s.nclients; i++ { go func() { - runSimClient(s.tester, s.network, s.addr, s.packetSize, s.packetBatch) + runSimClient(s.tester, s.network, s.addr, s.packetSize, s.batchWrite) }() } } @@ -1651,11 +1657,14 @@ func (codec testCodec) Encode(buf []byte) ([]byte, error) { return data, nil } -func (codec *testCodec) Decode(c Conn) ([]byte, error) { +func (codec testCodec) Decode(c Conn) ([]byte, error) { bodyOffset := magicNumberSize + bodySize - buf, _ := c.Peek(bodyOffset) - if len(buf) < bodyOffset { - return nil, errIncompletePacket + buf, err := c.Peek(bodyOffset) + if err != nil { + if errors.Is(err, io.ErrShortBuffer) { + err = errIncompletePacket + } + return nil, err } if !bytes.Equal(magicNumberBytes, buf[:magicNumberSize]) { @@ -1664,13 +1673,18 @@ func (codec *testCodec) Decode(c Conn) ([]byte, error) { bodyLen := binary.BigEndian.Uint32(buf[magicNumberSize:bodyOffset]) msgLen := bodyOffset + int(bodyLen) - if c.InboundBuffered() < msgLen { - return nil, errIncompletePacket + buf, err = c.Peek(msgLen) + if err != nil { + if errors.Is(err, io.ErrShortBuffer) { + err = errIncompletePacket + } + return nil, err } - buf, _ = c.Peek(msgLen) + body := make([]byte, bodyLen) + copy(body, buf[bodyOffset:msgLen]) _, _ = c.Discard(msgLen) - return buf[bodyOffset:msgLen], nil + return body, nil } func (codec testCodec) Unpack(buf []byte) ([]byte, error) { @@ -1693,41 +1707,48 @@ func (codec testCodec) Unpack(buf []byte) ([]byte, error) { } func TestSimServer(t *testing.T) { + t.Run("packet-size=64,batch=200", func(t *testing.T) { + runSimServer(t, ":7200", true, 10, 64, 200, -1) + }) t.Run("packet-size=128,batch=100", func(t *testing.T) { - runSimServer(t, ":7200", false, 10, 128, 100) + runSimServer(t, ":7201", false, 10, 128, 100, 10) }) t.Run("packet-size=256,batch=50", func(t *testing.T) { - runSimServer(t, ":7201", true, 10, 256, 50) + runSimServer(t, ":7202", true, 10, 256, 50, -1) }) t.Run("packet-size=512,batch=30", func(t *testing.T) { - runSimServer(t, ":7202", false, 10, 512, 30) + runSimServer(t, ":7203", false, 10, 512, 30, 3) }) t.Run("packet-size=1024,batch=20", func(t *testing.T) { - runSimServer(t, ":7203", true, 10, 1024, 20) + runSimServer(t, ":7204", true, 10, 1024, 20, -1) }) t.Run("packet-size=64*1024,batch=10", func(t *testing.T) { - runSimServer(t, ":7204", false, 10, 64*1024, 10) + runSimServer(t, ":7205", false, 10, 64*1024, 10, 1) }) t.Run("packet-size=128*1024,batch=5", func(t *testing.T) { - runSimServer(t, ":7205", true, 10, 128*1024, 5) + runSimServer(t, ":7206", true, 10, 128*1024, 5, -1) }) t.Run("packet-size=512*1024,batch=3", func(t *testing.T) { - runSimServer(t, ":7206", false, 10, 512*1024, 3) + runSimServer(t, ":7207", false, 10, 512*1024, 3, 1) }) t.Run("packet-size=1024*1024,batch=2", func(t *testing.T) { - runSimServer(t, ":7207", true, 10, 1024*1024, 2) + runSimServer(t, ":7208", true, 10, 1024*1024, 2, -1) }) } -func runSimServer(t *testing.T, addr string, et bool, nclients, packetSize, packetBatch int) { +func runSimServer(t *testing.T, addr string, et bool, nclients, packetSize, batchWrite, batchRead int) { ts := &simServer{ - tester: t, - network: "tcp", - addr: addr, - multicore: true, - nclients: nclients, - packetSize: packetSize, - packetBatch: packetBatch, + tester: t, + network: "tcp", + addr: addr, + multicore: true, + nclients: nclients, + packetSize: packetSize, + batchWrite: batchWrite, + batchRead: batchRead, + } + if batchRead < 0 { + ts.batchRead = math.MaxInt32 // unlimited read batch } err := Run(ts, ts.network+"://"+ts.addr, @@ -1789,6 +1810,7 @@ func batchSendAndRecv(t *testing.T, c net.Conn, rd *bufio.Reader, packetSize, ba for i, req := range requests { rsp, err := codec.Unpack(respPacket[i*packetLen:]) require.NoError(t, err) - require.Equalf(t, req, rsp, "request and response mismatch, packet size: %d, batch: %d", packetSize, batch) + require.Equalf(t, req, rsp, "request and response mismatch, packet size: %d, batch: %d, round: %d", + packetSize, batch, i) } } diff --git a/internal/netpoll/poller_epoll_default.go b/internal/netpoll/poller_epoll_default.go index ec47d0b46..39b6b0c1e 100644 --- a/internal/netpoll/poller_epoll_default.go +++ b/internal/netpoll/poller_epoll_default.go @@ -88,9 +88,9 @@ var ( // any asks other than high-priority tasks will be shunted to asyncTaskQueue. // // Note that asyncTaskQueue is a queue of low-priority whose size may grow large and tasks in it may backlog. -func (p *Poller) Trigger(priority queue.EventPriority, fn queue.TaskFunc, arg interface{}) (err error) { +func (p *Poller) Trigger(priority queue.EventPriority, fn queue.Func, param any) (err error) { task := queue.GetTask() - task.Run, task.Arg = fn, arg + task.Exec, task.Param = fn, param if priority > queue.HighPriority && p.urgentAsyncTaskQueue.Length() >= p.highPriorityEventsThreshold { p.asyncTaskQueue.Enqueue(task) } else { @@ -145,7 +145,7 @@ func (p *Poller) Polling(callback PollEventHandler) error { doChores = false task := p.urgentAsyncTaskQueue.Dequeue() for ; task != nil; task = p.urgentAsyncTaskQueue.Dequeue() { - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } @@ -155,7 +155,7 @@ func (p *Poller) Polling(callback PollEventHandler) error { if task = p.asyncTaskQueue.Dequeue(); task == nil { break } - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } diff --git a/internal/netpoll/poller_epoll_ultimate.go b/internal/netpoll/poller_epoll_ultimate.go index 730e0485b..acd3fb7d9 100644 --- a/internal/netpoll/poller_epoll_ultimate.go +++ b/internal/netpoll/poller_epoll_ultimate.go @@ -89,9 +89,9 @@ var ( // any asks other than high-priority tasks will be shunted to asyncTaskQueue. // // Note that asyncTaskQueue is a queue of low-priority whose size may grow large and tasks in it may backlog. -func (p *Poller) Trigger(priority queue.EventPriority, fn queue.TaskFunc, arg interface{}) (err error) { +func (p *Poller) Trigger(priority queue.EventPriority, fn queue.Func, param any) (err error) { task := queue.GetTask() - task.Run, task.Arg = fn, arg + task.Exec, task.Param = fn, param if priority > queue.HighPriority && p.urgentAsyncTaskQueue.Length() >= p.highPriorityEventsThreshold { p.asyncTaskQueue.Enqueue(task) } else { @@ -147,7 +147,7 @@ func (p *Poller) Polling() error { doChores = false task := p.urgentAsyncTaskQueue.Dequeue() for ; task != nil; task = p.urgentAsyncTaskQueue.Dequeue() { - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } @@ -157,7 +157,7 @@ func (p *Poller) Polling() error { if task = p.asyncTaskQueue.Dequeue(); task == nil { break } - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } diff --git a/internal/netpoll/poller_kqueue_default.go b/internal/netpoll/poller_kqueue_default.go index 4fbcd8e60..c0c513c4a 100644 --- a/internal/netpoll/poller_kqueue_default.go +++ b/internal/netpoll/poller_kqueue_default.go @@ -76,9 +76,9 @@ func (p *Poller) Close() error { // any asks other than high-priority tasks will be shunted to asyncTaskQueue. // // Note that asyncTaskQueue is a queue of low-priority whose size may grow large and tasks in it may backlog. -func (p *Poller) Trigger(priority queue.EventPriority, fn queue.TaskFunc, arg interface{}) (err error) { +func (p *Poller) Trigger(priority queue.EventPriority, fn queue.Func, param any) (err error) { task := queue.GetTask() - task.Run, task.Arg = fn, arg + task.Exec, task.Param = fn, param if priority > queue.HighPriority && p.urgentAsyncTaskQueue.Length() >= p.highPriorityEventsThreshold { p.asyncTaskQueue.Enqueue(task) } else { @@ -130,7 +130,7 @@ func (p *Poller) Polling(callback PollEventHandler) error { doChores = false task := p.urgentAsyncTaskQueue.Dequeue() for ; task != nil; task = p.urgentAsyncTaskQueue.Dequeue() { - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } @@ -140,7 +140,7 @@ func (p *Poller) Polling(callback PollEventHandler) error { if task = p.asyncTaskQueue.Dequeue(); task == nil { break } - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } diff --git a/internal/netpoll/poller_kqueue_ultimate.go b/internal/netpoll/poller_kqueue_ultimate.go index e12caf0c1..e523d257d 100644 --- a/internal/netpoll/poller_kqueue_ultimate.go +++ b/internal/netpoll/poller_kqueue_ultimate.go @@ -77,9 +77,9 @@ func (p *Poller) Close() error { // any asks other than high-priority tasks will be shunted to asyncTaskQueue. // // Note that asyncTaskQueue is a queue of low-priority whose size may grow large and tasks in it may backlog. -func (p *Poller) Trigger(priority queue.EventPriority, fn queue.TaskFunc, arg interface{}) (err error) { +func (p *Poller) Trigger(priority queue.EventPriority, fn queue.Func, param any) (err error) { task := queue.GetTask() - task.Run, task.Arg = fn, arg + task.Exec, task.Param = fn, param if priority > queue.HighPriority && p.urgentAsyncTaskQueue.Length() >= p.highPriorityEventsThreshold { p.asyncTaskQueue.Enqueue(task) } else { @@ -132,7 +132,7 @@ func (p *Poller) Polling() error { doChores = false task := p.urgentAsyncTaskQueue.Dequeue() for ; task != nil; task = p.urgentAsyncTaskQueue.Dequeue() { - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } @@ -142,7 +142,7 @@ func (p *Poller) Polling() error { if task = p.asyncTaskQueue.Dequeue(); task == nil { break } - err = task.Run(task.Arg) + err = task.Exec(task.Param) if errors.Is(err, errorx.ErrEngineShutdown) { return err } diff --git a/internal/queue/queue.go b/internal/queue/queue.go index 826f1843f..a85b0d7b9 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -16,16 +16,16 @@ package queue import "sync" -// TaskFunc is the callback function executed by poller. -type TaskFunc func(interface{}) error +// Func is the callback function executed by poller. +type Func func(any) error // Task is a wrapper that contains function and its argument. type Task struct { - Run TaskFunc - Arg interface{} + Exec Func + Param any } -var taskPool = sync.Pool{New: func() interface{} { return new(Task) }} +var taskPool = sync.Pool{New: func() any { return new(Task) }} // GetTask gets a cached Task from pool. func GetTask() *Task { @@ -34,7 +34,7 @@ func GetTask() *Task { // PutTask puts the trashy Task back in pool. func PutTask(task *Task) { - task.Run, task.Arg = nil, nil + task.Exec, task.Param = nil, nil taskPool.Put(task) } diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go index 86ad17e93..5ccc3e7da 100644 --- a/pkg/logging/logger.go +++ b/pkg/logging/logger.go @@ -244,35 +244,35 @@ func Error(err error) { } // Debugf logs messages at DEBUG level. -func Debugf(format string, args ...interface{}) { +func Debugf(format string, args ...any) { mu.RLock() defaultLogger.Debugf(format, args...) mu.RUnlock() } // Infof logs messages at INFO level. -func Infof(format string, args ...interface{}) { +func Infof(format string, args ...any) { mu.RLock() defaultLogger.Infof(format, args...) mu.RUnlock() } // Warnf logs messages at WARN level. -func Warnf(format string, args ...interface{}) { +func Warnf(format string, args ...any) { mu.RLock() defaultLogger.Warnf(format, args...) mu.RUnlock() } // Errorf logs messages at ERROR level. -func Errorf(format string, args ...interface{}) { +func Errorf(format string, args ...any) { mu.RLock() defaultLogger.Errorf(format, args...) mu.RUnlock() } // Fatalf logs messages at FATAL level. -func Fatalf(format string, args ...interface{}) { +func Fatalf(format string, args ...any) { mu.RLock() defaultLogger.Fatalf(format, args...) mu.RUnlock() @@ -281,13 +281,13 @@ func Fatalf(format string, args ...interface{}) { // Logger is used for logging formatted messages. type Logger interface { // Debugf logs messages at DEBUG level. - Debugf(format string, args ...interface{}) + Debugf(format string, args ...any) // Infof logs messages at INFO level. - Infof(format string, args ...interface{}) + Infof(format string, args ...any) // Warnf logs messages at WARN level. - Warnf(format string, args ...interface{}) + Warnf(format string, args ...any) // Errorf logs messages at ERROR level. - Errorf(format string, args ...interface{}) + Errorf(format string, args ...any) // Fatalf logs messages at FATAL level. - Fatalf(format string, args ...interface{}) + Fatalf(format string, args ...any) } diff --git a/pkg/pool/goroutine/goroutine.go b/pkg/pool/goroutine/goroutine.go index fa7a59d12..41ae371a7 100644 --- a/pkg/pool/goroutine/goroutine.go +++ b/pkg/pool/goroutine/goroutine.go @@ -47,7 +47,7 @@ type antsLogger struct { } // Printf implements the ants.Logger interface. -func (l antsLogger) Printf(format string, args ...interface{}) { +func (l antsLogger) Printf(format string, args ...any) { l.Infof(format, args...) } @@ -57,8 +57,8 @@ func Default() *Pool { ExpiryDuration: ExpiryDuration, Nonblocking: Nonblocking, Logger: &antsLogger{logging.GetDefaultLogger()}, - PanicHandler: func(i interface{}) { - logging.Errorf("goroutine pool panic: %v", i) + PanicHandler: func(a any) { + logging.Errorf("goroutine pool panic: %v", a) }, } defaultAntsPool, _ := ants.NewPool(DefaultAntsPoolSize, ants.WithOptions(options))