diff --git a/endless.go b/endless.go index 9604a20..e9c7be7 100644 --- a/endless.go +++ b/endless.go @@ -1,6 +1,7 @@ package endless import ( + "context" "crypto/tls" "errors" "fmt" @@ -15,7 +16,6 @@ import ( "sync" "syscall" "time" - // "github.com/fvbock/uds-go/introspect" ) @@ -80,13 +80,15 @@ type endlessServer struct { state uint8 lock *sync.RWMutex BeforeBegin func(add string) + Ctx context.Context + Cancel context.CancelFunc } /* NewServer returns an intialized endlessServer Object. Calling Serve on it will actually "start" the server. */ -func NewServer(addr string, handler http.Handler) (srv *endlessServer) { +func NewServer(ctx context.Context, addr string, handler http.Handler) (srv *endlessServer) { runningServerReg.Lock() defer runningServerReg.Unlock() @@ -132,6 +134,7 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) { srv.Server.WriteTimeout = DefaultWriteTimeOut srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes srv.Server.Handler = handler + srv.Ctx, srv.Cancel = context.WithCancel(ctx) srv.BeforeBegin = func(addr string) { log.Println(syscall.Getpid(), addr) @@ -143,16 +146,38 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) { return } +/* +ListenAndServeCtx listens on the TCP network address addr and then calls Serve +with handler to handle requests on incoming connections. Handler is typically +nil, in which case the DefaultServeMux is used. +*/ +func ListenAndServeCtx(ctx context.Context, addr string, handler http.Handler) error { + server := NewServer(ctx, addr, handler) + return server.ListenAndServe() +} + /* ListenAndServe listens on the TCP network address addr and then calls Serve with handler to handle requests on incoming connections. Handler is typically nil, in which case the DefaultServeMux is used. */ func ListenAndServe(addr string, handler http.Handler) error { - server := NewServer(addr, handler) + server := NewServer(context.Background(), addr, handler) return server.ListenAndServe() } +/* +ListenAndServeTLSCtx acts identically to ListenAndServe, except that it expects +HTTPS connections. Additionally, files containing a certificate and matching +private key for the server must be provided. If the certificate is signed by a +certificate authority, the certFile should be the concatenation of the server's +certificate followed by the CA's certificate. +*/ +func ListenAndServeTLSCtx(ctx context.Context, addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(ctx, addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} + /* ListenAndServeTLS acts identically to ListenAndServe, except that it expects HTTPS connections. Additionally, files containing a certificate and matching @@ -161,7 +186,7 @@ certificate authority, the certFile should be the concatenation of the server's certificate followed by the CA's certificate. */ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { - server := NewServer(addr, handler) + server := NewServer(context.Background(), addr, handler) return server.ListenAndServeTLS(certFile, keyFile) } @@ -192,7 +217,9 @@ down the server. func (srv *endlessServer) Serve() (err error) { defer log.Println(syscall.Getpid(), "Serve() returning...") srv.setState(STATE_RUNNING) + srv.Server.BaseContext = func(list net.Listener) context.Context { return srv.Ctx } err = srv.Server.Serve(srv.EndlessListener) + log.Println(syscall.Getpid(), "Waiting for connections to finish...") srv.wg.Wait() srv.setState(STATE_TERMINATE) @@ -376,6 +403,7 @@ func (srv *endlessServer) shutdown() { if DefaultHammerTime >= 0 { go srv.hammerTime(DefaultHammerTime) } + srv.Cancel() // disable keep-alives on existing connections srv.SetKeepAlivesEnabled(false) err := srv.EndlessListener.Close() diff --git a/examples/hook.go b/examples/hook.go index 1b4ceae..a65c82a 100644 --- a/examples/hook.go +++ b/examples/hook.go @@ -6,8 +6,8 @@ import ( "os" "syscall" - "github.com/fvbock/endless" "github.com/gorilla/mux" + "github.com/voyager3m/endless/v2" ) func handler(w http.ResponseWriter, r *http.Request) { diff --git a/examples/multi_port.go b/examples/multi_port.go index fe98e04..f15a02e 100644 --- a/examples/multi_port.go +++ b/examples/multi_port.go @@ -1,14 +1,15 @@ package main import ( + "context" "log" "net/http" "os" "sync" "time" - "github.com/fvbock/endless" "github.com/gorilla/mux" + "github.com/voyager3m/endless" ) func handler(w http.ResponseWriter, r *http.Request) { @@ -41,7 +42,7 @@ func main() { w.Add(2) go func() { time.Sleep(time.Second) - err := endless.ListenAndServe("localhost:4242", mux1) + err := endless.ListenAndServe(context.Background(), "localhost:4242", mux1) if err != nil { log.Println(err) } @@ -49,7 +50,7 @@ func main() { w.Done() }() go func() { - err := endless.ListenAndServe("localhost:4243", mux2) + err := endless.ListenAndServe(context.Background(), "localhost:4243", mux2) if err != nil { log.Println(err) } diff --git a/examples/simple.go b/examples/simple.go index d62817e..87c5db6 100644 --- a/examples/simple.go +++ b/examples/simple.go @@ -1,12 +1,13 @@ package main import ( + "context" "log" "net/http" "os" - "github.com/fvbock/endless" "github.com/gorilla/mux" + "github.com/voyager3m/endless" ) func handler(w http.ResponseWriter, r *http.Request) { @@ -18,7 +19,7 @@ func main() { mux1.HandleFunc("/hello", handler). Methods("GET") - err := endless.ListenAndServe("localhost:4242", mux1) + err := endless.ListenAndServe(context.Background(), "localhost:4242", mux1) if err != nil { log.Println(err) } diff --git a/examples/testserver.go b/examples/testserver.go index 3f4ddc1..965766d 100644 --- a/examples/testserver.go +++ b/examples/testserver.go @@ -1,14 +1,15 @@ package main import ( + "context" "log" "math/rand" "net/http" "os" "time" - "github.com/fvbock/endless" "github.com/gorilla/mux" + "github.com/voyager3m/endless" ) func handler(w http.ResponseWriter, r *http.Request) { @@ -22,7 +23,7 @@ func main() { mux.HandleFunc("/foo", handler). Methods("GET") - err := endless.ListenAndServe("localhost:4242", mux) + err := endless.ListenAndServe(context.Background(), "localhost:4242", mux) if err != nil { log.Println(err) } diff --git a/examples/tls.go b/examples/tls.go index ddffa04..a4afcff 100644 --- a/examples/tls.go +++ b/examples/tls.go @@ -1,12 +1,13 @@ package main import ( + "context" "log" "net/http" "os" - "github.com/fvbock/endless" "github.com/gorilla/mux" + "github.com/voyager3m/endless" ) func handler(w http.ResponseWriter, r *http.Request) { @@ -19,7 +20,7 @@ func main() { mux1.HandleFunc("/hello", handler). Methods("GET") - err := endless.ListenAndServeTLS("localhost:4242", "cert.pem", "key.pem", mux1) + err := endless.ListenAndServeTLS(context.Background(), "localhost:4242", "cert.pem", "key.pem", mux1) if err != nil { log.Println(err) } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..dc5a67a --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/voyager3m/endless/v2 + +go 1.15