diff --git a/acceptor_windows.go b/acceptor_windows.go index d9d6339cd..11ea152fd 100644 --- a/acceptor_windows.go +++ b/acceptor_windows.go @@ -18,7 +18,6 @@ import ( "errors" "net" "runtime" - "sync/atomic" errorx "github.com/panjf2000/gnet/v2/pkg/errors" ) @@ -36,7 +35,7 @@ func (eng *engine) listenStream(ln net.Listener) (err error) { tc, e := ln.Accept() if e != nil { err = e - if atomic.LoadInt32(&eng.beingShutdown) == 0 { + if !eng.beingShutdown.Load() { eng.opts.Logger.Errorf("Accept() fails due to error: %v", err) } else if errors.Is(err, net.ErrClosed) { err = errors.Join(err, errorx.ErrEngineShutdown) @@ -74,7 +73,7 @@ func (eng *engine) ListenUDP(pc net.PacketConn) (err error) { n, addr, e := pc.ReadFrom(buffer[:]) if e != nil { err = e - if atomic.LoadInt32(&eng.beingShutdown) == 0 { + if !eng.beingShutdown.Load() { eng.opts.Logger.Errorf("failed to receive data from UDP fd due to error:%v", err) } else if errors.Is(err, net.ErrClosed) { err = errors.Join(err, errorx.ErrEngineShutdown) diff --git a/client_unix.go b/client_unix.go index 7fef073b9..728b3b615 100644 --- a/client_unix.go +++ b/client_unix.go @@ -22,7 +22,6 @@ import ( "errors" "net" "strconv" - "sync" "syscall" "golang.org/x/sync/errgroup" @@ -66,20 +65,17 @@ func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { return } - shutdownCtx, shutdown := context.WithCancel(context.Background()) + rootCtx, shutdown := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(rootCtx) eng := engine{ listeners: make(map[int]*listener), opts: options, + turnOff: shutdown, eventHandler: eh, - workerPool: struct { + concurrency: struct { *errgroup.Group - shutdownCtx context.Context - shutdown context.CancelFunc - once sync.Once - }{&errgroup.Group{}, shutdownCtx, shutdown, sync.Once{}}, - } - if options.Ticker { - eng.ticker.ctx, eng.ticker.cancel = context.WithCancel(context.Background()) + ctx context.Context + }{eg, ctx}, } el := eventloop{ listeners: eng.listeners, @@ -124,10 +120,14 @@ func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { func (cli *Client) Start() error { logging.Infof("Starting gnet client with 1 event-loop") cli.el.eventHandler.OnBoot(Engine{cli.el.engine}) - cli.el.engine.workerPool.Go(cli.el.run) + cli.el.engine.concurrency.Go(cli.el.run) // Start the ticker. if cli.opts.Ticker { - go cli.el.ticker(cli.el.engine.ticker.ctx) + ctx := cli.el.engine.concurrency.ctx + cli.el.engine.concurrency.Go(func() error { + cli.el.ticker(ctx) + return nil + }) } logging.Debugf("default logging level is %s", logging.LogLevel()) return nil @@ -136,11 +136,7 @@ func (cli *Client) Start() error { // Stop stops the client event-loop. func (cli *Client) Stop() (err error) { logging.Error(cli.el.poller.Trigger(queue.HighPriority, func(_ any) error { return errorx.ErrEngineShutdown }, nil)) - // Stop the ticker. - if cli.opts.Ticker { - cli.el.engine.ticker.cancel() - } - _ = cli.el.engine.workerPool.Wait() + err = cli.el.engine.concurrency.Wait() logging.Error(cli.el.poller.Close()) cli.el.eventHandler.OnShutdown(Engine{cli.el.engine}) logging.Cleanup() diff --git a/client_windows.go b/client_windows.go index ae2750bb7..96806414d 100644 --- a/client_windows.go +++ b/client_windows.go @@ -48,17 +48,17 @@ func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { } logging.SetDefaultLoggerAndFlusher(logger, logFlusher) - shutdownCtx, shutdown := context.WithCancel(context.Background()) + rootCtx, shutdown := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(rootCtx) eng := &engine{ - listeners: []*listener{}, - opts: options, - workerPool: struct { - *errgroup.Group - shutdownCtx context.Context - shutdown context.CancelFunc - once sync.Once - }{&errgroup.Group{}, shutdownCtx, shutdown, sync.Once{}}, + listeners: []*listener{}, + opts: options, + turnOff: shutdown, eventHandler: eh, + concurrency: struct { + *errgroup.Group + ctx context.Context + }{eg, ctx}, } cli.el = &eventloop{ ch: make(chan any, 1024), @@ -71,11 +71,11 @@ func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { func (cli *Client) Start() error { cli.el.eventHandler.OnBoot(Engine{cli.el.eng}) - cli.el.eng.workerPool.Go(cli.el.run) + cli.el.eng.concurrency.Go(cli.el.run) if cli.opts.Ticker { - cli.el.eng.ticker.ctx, cli.el.eng.ticker.cancel = context.WithCancel(context.Background()) - cli.el.eng.workerPool.Go(func() error { - cli.el.ticker(cli.el.eng.ticker.ctx) + ctx := cli.el.eng.concurrency.ctx + cli.el.eng.concurrency.Go(func() error { + cli.el.ticker(ctx) return nil }) } @@ -85,10 +85,7 @@ func (cli *Client) Start() error { func (cli *Client) Stop() (err error) { cli.el.ch <- errorx.ErrEngineShutdown - if cli.opts.Ticker { - cli.el.eng.ticker.cancel() - } - _ = cli.el.eng.workerPool.Wait() + err = cli.el.eng.concurrency.Wait() cli.el.eventHandler.OnShutdown(Engine{cli.el.eng}) logging.Cleanup() return diff --git a/engine_unix.go b/engine_unix.go index 607e2f01f..0ede0cdca 100644 --- a/engine_unix.go +++ b/engine_unix.go @@ -22,7 +22,6 @@ import ( "errors" "runtime" "strings" - "sync" "sync/atomic" "golang.org/x/sync/errgroup" @@ -35,27 +34,22 @@ import ( ) type engine struct { - listeners map[int]*listener // listeners for accepting incoming connections - opts *Options // options with engine - ingress *eventloop // main event-loop that monitors all listeners - eventLoops loadBalancer // event-loops for handling events - inShutdown int32 // whether the engine is in shutdown - ticker struct { - ctx context.Context // context for ticker - cancel context.CancelFunc // function to stop the ticker - } - workerPool struct { + listeners map[int]*listener // listeners for accepting incoming connections + opts *Options // options with engine + ingress *eventloop // main event-loop that monitors all listeners + eventLoops loadBalancer // event-loops for handling events + inShutdown atomic.Bool // whether the engine is in shutdown + turnOff context.CancelFunc + eventHandler EventHandler // user eventHandler + concurrency struct { *errgroup.Group - shutdownCtx context.Context - shutdown context.CancelFunc - once sync.Once + ctx context.Context } - eventHandler EventHandler // user eventHandler } -func (eng *engine) isInShutdown() bool { - return atomic.LoadInt32(&eng.inShutdown) == 1 +func (eng *engine) isShutdown() bool { + return eng.inShutdown.Load() } // shutdown signals the engine to shut down. @@ -64,9 +58,7 @@ func (eng *engine) shutdown(err error) { eng.opts.Logger.Errorf("engine is being shutdown with error: %v", err) } - eng.workerPool.once.Do(func() { - eng.workerPool.shutdown() - }) + eng.turnOff() } func (eng *engine) closeEventLoops() { @@ -88,7 +80,7 @@ func (eng *engine) closeEventLoops() { } } -func (eng *engine) runEventLoops(numEventLoop int) error { +func (eng *engine) runEventLoops(ctx context.Context, numEventLoop int) error { var el0 *eventloop lns := eng.listeners // Create loops locally and bind the listeners. @@ -129,13 +121,13 @@ func (eng *engine) runEventLoops(numEventLoop int) error { // Start event-loops in background. eng.eventLoops.iterate(func(_ int, el *eventloop) bool { - eng.workerPool.Go(el.run) + eng.concurrency.Go(el.run) return true }) if el0 != nil { - eng.workerPool.Go(func() error { - el0.ticker(eng.ticker.ctx) + eng.concurrency.Go(func() error { + el0.ticker(ctx) return nil }) } @@ -143,7 +135,7 @@ func (eng *engine) runEventLoops(numEventLoop int) error { return nil } -func (eng *engine) activateReactors(numEventLoop int) error { +func (eng *engine) activateReactors(ctx context.Context, numEventLoop int) error { for i := 0; i < numEventLoop; i++ { p, err := netpoll.OpenPoller() if err != nil { @@ -161,7 +153,7 @@ func (eng *engine) activateReactors(numEventLoop int) error { // Start sub reactors in background. eng.eventLoops.iterate(func(_ int, el *eventloop) bool { - eng.workerPool.Go(el.orbit) + eng.concurrency.Go(el.orbit) return true }) @@ -183,12 +175,12 @@ func (eng *engine) activateReactors(numEventLoop int) error { eng.ingress = el // Start main reactor in background. - eng.workerPool.Go(el.rotate) + eng.concurrency.Go(el.rotate) // Start the ticker. if eng.opts.Ticker { - eng.workerPool.Go(func() error { - eng.ingress.ticker(eng.ticker.ctx) + eng.concurrency.Go(func() error { + eng.ingress.ticker(ctx) return nil }) } @@ -196,17 +188,17 @@ func (eng *engine) activateReactors(numEventLoop int) error { return nil } -func (eng *engine) start(numEventLoop int) error { +func (eng *engine) start(ctx context.Context, numEventLoop int) error { if eng.opts.ReusePort { - return eng.runEventLoops(numEventLoop) + return eng.runEventLoops(ctx, numEventLoop) } - return eng.activateReactors(numEventLoop) + return eng.activateReactors(ctx, numEventLoop) } -func (eng *engine) stop(s Engine) { +func (eng *engine) stop(ctx context.Context, s Engine) { // Wait on a signal for shutdown - <-eng.workerPool.shutdownCtx.Done() + <-ctx.Done() eng.eventHandler.OnShutdown(s) @@ -225,12 +217,7 @@ func (eng *engine) stop(s Engine) { } } - // Stop the ticker. - if eng.ticker.cancel != nil { - eng.ticker.cancel() - } - - if err := eng.workerPool.Wait(); err != nil { + if err := eng.concurrency.Wait(); err != nil { eng.opts.Logger.Errorf("engine shutdown error: %v", err) } @@ -238,7 +225,7 @@ func (eng *engine) stop(s Engine) { eng.closeEventLoops() // Put the engine into the shutdown state. - atomic.StoreInt32(&eng.inShutdown, 1) + eng.inShutdown.Store(true) } func run(eventHandler EventHandler, listeners []*listener, options *Options, addrs []string) error { @@ -261,17 +248,17 @@ func run(eventHandler EventHandler, listeners []*listener, options *Options, add for _, ln := range listeners { lns[ln.fd] = ln } - shutdownCtx, shutdown := context.WithCancel(context.Background()) + rootCtx, shutdown := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(rootCtx) eng := engine{ - listeners: lns, - opts: options, - workerPool: struct { - *errgroup.Group - shutdownCtx context.Context - shutdown context.CancelFunc - once sync.Once - }{&errgroup.Group{}, shutdownCtx, shutdown, sync.Once{}}, + listeners: lns, + opts: options, + turnOff: shutdown, eventHandler: eventHandler, + concurrency: struct { + *errgroup.Group + ctx context.Context + }{eg, ctx}, } switch options.LB { case RoundRobin: @@ -282,23 +269,19 @@ func run(eventHandler EventHandler, listeners []*listener, options *Options, add eng.eventLoops = new(sourceAddrHashLoadBalancer) } - if eng.opts.Ticker { - eng.ticker.ctx, eng.ticker.cancel = context.WithCancel(context.Background()) - } - e := Engine{&eng} switch eng.eventHandler.OnBoot(e) { - case None: + case None, Close: case Shutdown: return nil } - if err := eng.start(numEventLoop); err != nil { + if err := eng.start(ctx, numEventLoop); err != nil { eng.closeEventLoops() eng.opts.Logger.Errorf("gnet engine is stopping with error: %v", err) return err } - defer eng.stop(e) + defer eng.stop(rootCtx, e) for _, addr := range addrs { allEngines.Store(addr, &eng) diff --git a/engine_windows.go b/engine_windows.go index 3eabb03ba..3137cf4a8 100644 --- a/engine_windows.go +++ b/engine_windows.go @@ -19,7 +19,6 @@ import ( "errors" "runtime" "strings" - "sync" "sync/atomic" "golang.org/x/sync/errgroup" @@ -29,27 +28,22 @@ import ( ) type engine struct { - listeners []*listener - opts *Options // options with engine - eventLoops loadBalancer // event-loops for handling events - ticker struct { - ctx context.Context - cancel context.CancelFunc - } - inShutdown int32 // whether the engine is in shutdown - beingShutdown int32 // whether the engine is being shutdown - workerPool struct { + listeners []*listener + opts *Options // options with engine + eventLoops loadBalancer // event-loops for handling events + inShutdown atomic.Bool // whether the engine is in shutdown + beingShutdown atomic.Bool // whether the engine is being shutdown + turnOff context.CancelFunc + eventHandler EventHandler // user eventHandler + concurrency struct { *errgroup.Group - shutdownCtx context.Context - shutdown context.CancelFunc - once sync.Once + ctx context.Context } - eventHandler EventHandler // user eventHandler } -func (eng *engine) isInShutdown() bool { - return atomic.LoadInt32(&eng.inShutdown) == 1 +func (eng *engine) isShutdown() bool { + return eng.inShutdown.Load() } // shutdown signals the engine to shut down. @@ -57,8 +51,8 @@ func (eng *engine) shutdown(err error) { if err != nil && !errors.Is(err, errorx.ErrEngineShutdown) { eng.opts.Logger.Errorf("engine is being shutdown with error: %v", err) } - eng.workerPool.shutdown() - atomic.StoreInt32(&eng.beingShutdown, 1) + eng.turnOff() + eng.beingShutdown.Store(true) } func (eng *engine) closeEventLoops() { @@ -71,7 +65,7 @@ func (eng *engine) closeEventLoops() { } } -func (eng *engine) start(numEventLoop int) error { +func (eng *engine) start(ctx context.Context, numEventLoop int) error { for i := 0; i < numEventLoop; i++ { el := eventloop{ ch: make(chan any, 1024), @@ -81,10 +75,10 @@ func (eng *engine) start(numEventLoop int) error { eventHandler: eng.eventHandler, } eng.eventLoops.register(&el) - eng.workerPool.Go(el.run) + eng.concurrency.Go(el.run) if i == 0 && eng.opts.Ticker { - eng.workerPool.Go(func() error { - el.ticker(eng.ticker.ctx) + eng.concurrency.Go(func() error { + el.ticker(ctx) return nil }) } @@ -93,11 +87,11 @@ func (eng *engine) start(numEventLoop int) error { for _, ln := range eng.listeners { l := ln if l.pc != nil { - eng.workerPool.Go(func() error { + eng.concurrency.Go(func() error { return eng.ListenUDP(l.pc) }) } else { - eng.workerPool.Go(func() error { + eng.concurrency.Go(func() error { return eng.listenStream(l.ln) }) } @@ -106,24 +100,18 @@ func (eng *engine) start(numEventLoop int) error { return nil } -func (eng *engine) stop(engine Engine) error { - <-eng.workerPool.shutdownCtx.Done() +func (eng *engine) stop(ctx context.Context, engine Engine) { + <-ctx.Done() eng.eventHandler.OnShutdown(engine) - if eng.ticker.cancel != nil { - eng.ticker.cancel() - } - eng.closeEventLoops() - if err := eng.workerPool.Wait(); err != nil && !errors.Is(err, errorx.ErrEngineShutdown) { + if err := eng.concurrency.Wait(); err != nil && !errors.Is(err, errorx.ErrEngineShutdown) { eng.opts.Logger.Errorf("engine shutdown error: %v", err) } - atomic.StoreInt32(&eng.inShutdown, 1) - - return nil + eng.inShutdown.Store(true) } func run(eventHandler EventHandler, listeners []*listener, options *Options, addrs []string) error { @@ -139,17 +127,17 @@ func run(eventHandler EventHandler, listeners []*listener, options *Options, add logging.Infof("Launching gnet with %d event-loops, listening on: %s", numEventLoop, strings.Join(addrs, " | ")) - shutdownCtx, shutdown := context.WithCancel(context.Background()) + rootCtx, shutdown := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(rootCtx) eng := engine{ opts: options, - eventHandler: eventHandler, listeners: listeners, - workerPool: struct { + turnOff: shutdown, + eventHandler: eventHandler, + concurrency: struct { *errgroup.Group - shutdownCtx context.Context - shutdown context.CancelFunc - once sync.Once - }{&errgroup.Group{}, shutdownCtx, shutdown, sync.Once{}}, + ctx context.Context + }{eg, ctx}, } switch options.LB { @@ -166,22 +154,18 @@ func run(eventHandler EventHandler, listeners []*listener, options *Options, add eng.eventLoops = new(sourceAddrHashLoadBalancer) } - if options.Ticker { - eng.ticker.ctx, eng.ticker.cancel = context.WithCancel(context.Background()) - } - engine := Engine{eng: &eng} switch eventHandler.OnBoot(engine) { - case None: + case None, Close: case Shutdown: return nil } - if err := eng.start(numEventLoop); err != nil { + if err := eng.start(ctx, numEventLoop); err != nil { eng.opts.Logger.Errorf("gnet engine is stopping with error: %v", err) return err } - defer eng.stop(engine) //nolint:errcheck + defer eng.stop(rootCtx, engine) for _, addr := range addrs { allEngines.Store(addr, &eng) diff --git a/gnet.go b/gnet.go index b2572932f..fe29427fc 100644 --- a/gnet.go +++ b/gnet.go @@ -54,7 +54,7 @@ func (e Engine) Validate() error { if e.eng == nil || len(e.eng.listeners) == 0 { return errors.ErrEmptyEngine } - if e.eng.isInShutdown() { + if e.eng.isShutdown() { return errors.ErrEngineInShutdown } return nil @@ -101,7 +101,7 @@ func (e Engine) Stop(ctx context.Context) error { ticker := time.NewTicker(shutdownPollInterval) defer ticker.Stop() for { - if e.eng.isInShutdown() { + if e.eng.isShutdown() { return nil } select { @@ -599,14 +599,14 @@ func Stop(ctx context.Context, protoAddr string) error { return errors.ErrEngineInShutdown } - if eng.isInShutdown() { + if eng.isShutdown() { return errors.ErrEngineInShutdown } ticker := time.NewTicker(shutdownPollInterval) defer ticker.Stop() for { - if eng.isInShutdown() { + if eng.isShutdown() { return nil } select {