diff --git a/ap/resolver.go b/ap/resolver.go index a1e07395..ac95836c 100644 --- a/ap/resolver.go +++ b/ap/resolver.go @@ -18,7 +18,6 @@ package ap import ( "context" - "net/http" "github.com/dimkr/tootik/httpsig" ) @@ -40,5 +39,5 @@ const ( type Resolver interface { ResolveID(ctx context.Context, keys [2]httpsig.Key, id string, flags ResolverFlag) (*Actor, error) Resolve(ctx context.Context, keys [2]httpsig.Key, host, name string, flags ResolverFlag) (*Actor, error) - Get(ctx context.Context, keys [2]httpsig.Key, url string) (*http.Response, error) + Get(ctx context.Context, keys [2]httpsig.Key, url string) (int, []byte, func(), error) } diff --git a/cluster/server.go b/cluster/server.go index 81ea6494..834c8db3 100644 --- a/cluster/server.go +++ b/cluster/server.go @@ -188,6 +188,11 @@ func NewServer(ctx context.Context, t *testing.T, domain string, client fed.Clie DB: db, ActorKeys: nobodyKeys, Resolver: resolver, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, cfg.MaxRequestBodySize) + }, + }, }).NewHandler() if err != nil { t.Fatalf("Failed to run create the federation handler: %v", err) @@ -220,6 +225,11 @@ func NewServer(ctx context.Context, t *testing.T, domain string, client fed.Clie Config: &cfg, DB: db, Handler: handler, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, }, Backend: backend, Inbox: localInbox, diff --git a/cmd/tootik/main.go b/cmd/tootik/main.go index 15be2f85..93c569a7 100644 --- a/cmd/tootik/main.go +++ b/cmd/tootik/main.go @@ -342,6 +342,11 @@ func main() { Key: *key, Plain: *plain, BlockList: blockList, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, cfg.MaxRequestBodySize) + }, + }, }, }, { @@ -354,6 +359,11 @@ func main() { Addr: *gemAddr, CertPath: *gemCert, KeyPath: *gemKey, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024+2) + }, + }, }, }, { @@ -363,6 +373,11 @@ func main() { Config: &cfg, Handler: handler, Addr: *gopherAddr, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 256) + }, + }, }, }, { @@ -372,6 +387,11 @@ func main() { Config: &cfg, DB: db, Addr: *fingerAddr, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 34) + }, + }, }, }, { diff --git a/fed/deliver.go b/fed/deliver.go index 96086a0e..24b94386 100644 --- a/fed/deliver.go +++ b/fed/deliver.go @@ -258,9 +258,9 @@ func (q *Queue) deliverWithTimeout(parent context.Context, task *deliveryTask) e req := task.Request.WithContext(ctx) - resp, err := q.Resolver.send(task.Keys, req) + _, _, cleanup, err := q.Resolver.send(task.Keys, req) if err == nil { - resp.Body.Close() + cleanup() } return err } diff --git a/fed/followers.go b/fed/followers.go index 87ef11d6..25afd9e9 100644 --- a/fed/followers.go +++ b/fed/followers.go @@ -23,7 +23,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "log/slog" "net/http" "net/url" @@ -248,20 +247,16 @@ func (d *followersDigest) Sync(ctx context.Context, domain string, cfg *cfg.Conf slog.Info("Synchronizing followers", "followed", d.Followed) - resp, err := resolver.Get(ctx, keys, d.URL) + _, body, cleanup, err := resolver.Get(ctx, keys, d.URL) if err != nil { return err } - defer resp.Body.Close() - - if resp.ContentLength > cfg.MaxResponseBodySize { - return errors.New("response is too big") - } + defer cleanup() var remote struct { OrderedItems ap.Audience `json:"orderedItems"` } - if err := json.NewDecoder(io.LimitReader(resp.Body, cfg.MaxResponseBodySize)).Decode(&remote); err != nil { + if err := json.Unmarshal(body, &remote); err != nil { return err } diff --git a/fed/inbox.go b/fed/inbox.go index 85dcd0c8..a7de2a45 100644 --- a/fed/inbox.go +++ b/fed/inbox.go @@ -219,23 +219,14 @@ func (l *Listener) validateActivity(activity *ap.Activity, origin string, depth } func (l *Listener) fetchObject(ctx context.Context, id string, keys [2]httpsig.Key) (bool, []byte, error) { - resp, err := l.Resolver.Get(ctx, keys, id) + statusCode, body, cleanup, err := l.Resolver.Get(ctx, keys, id) if err != nil { - if resp != nil && (resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusGone) { + if statusCode == http.StatusNotFound || statusCode == http.StatusGone { return false, nil, err } return true, nil, err } - defer resp.Body.Close() - - if resp.ContentLength > l.Config.MaxRequestBodySize { - return true, nil, fmt.Errorf("object is too big: %d", resp.ContentLength) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, l.Config.MaxRequestBodySize)) - if err != nil { - return true, nil, err - } + defer cleanup() if !ap.IsPortable(id) { return true, body, nil @@ -320,11 +311,15 @@ func (l *Listener) doHandleInbox(w http.ResponseWriter, r *http.Request, keys [2 return } - rawActivity, err := io.ReadAll(io.LimitReader(r.Body, l.Config.MaxRequestBodySize)) - if err != nil { + buf := l.Buffers.Get().([]byte) + defer l.Buffers.Put(buf) + + n, err := r.Body.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { w.WriteHeader(http.StatusInternalServerError) return } + rawActivity := buf[:n] var activity ap.Activity if err := json.Unmarshal(rawActivity, &activity); err != nil { diff --git a/fed/listener.go b/fed/listener.go index 9eacb401..19e8de01 100644 --- a/fed/listener.go +++ b/fed/listener.go @@ -46,6 +46,7 @@ type Listener struct { Key string Plain bool BlockList *BlockList + Buffers sync.Pool } const certReloadDelay = time.Second * 5 diff --git a/fed/resolve.go b/fed/resolve.go index 16733698..9e0d78e1 100644 --- a/fed/resolve.go +++ b/fed/resolve.go @@ -23,12 +23,12 @@ import ( "errors" "fmt" "hash/crc32" - "io" "log/slog" "net" "net/http" "net/url" "strings" + "sync" "time" "github.com/dimkr/tootik/ap" @@ -69,6 +69,11 @@ func NewResolver(blockedDomains *BlockList, domain string, cfg *cfg.Config, clie Config: cfg, client: client, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, cfg.MaxResponseBodySize) + }, + }, }, BlockedDomains: blockedDomains, db: db, @@ -174,8 +179,8 @@ func deleteActor(ctx context.Context, db *sql.DB, id string) { } } -func (r *Resolver) handleFetchFailure(ctx context.Context, fetched string, cachedActor *ap.Actor, sinceLastUpdate time.Duration, resp *http.Response, err error) (*ap.Actor, *ap.Actor, error) { - if resp != nil && (resp.StatusCode == http.StatusGone || resp.StatusCode == http.StatusNotFound) { +func (r *Resolver) handleFetchFailure(ctx context.Context, fetched string, cachedActor *ap.Actor, sinceLastUpdate time.Duration, statusCode int, err error) (*ap.Actor, *ap.Actor, error) { + if statusCode == http.StatusGone || statusCode == http.StatusNotFound { if cachedActor != nil { slog.Warn("Actor is gone, deleting associated objects", "id", cachedActor.ID) deleteActor(ctx, r.db, cachedActor.ID) @@ -289,18 +294,14 @@ func (r *Resolver) tryResolve(ctx context.Context, keys [2]httpsig.Key, host, na } req.Header.Add("Accept", "application/json") - resp, err := r.send(keys, req) + statusCode, body, cleanup, err := r.send(keys, req) if err != nil { - return r.handleFetchFailure(ctx, finger, cachedActor, sinceLastUpdate, resp, err) - } - defer resp.Body.Close() - - if resp.ContentLength > r.Config.MaxResponseBodySize { - return nil, cachedActor, fmt.Errorf("failed to decode %s response: response is too big", finger) + return r.handleFetchFailure(ctx, finger, cachedActor, sinceLastUpdate, statusCode, err) } + defer cleanup() var webFingerResponse webFingerResponse - if err := json.NewDecoder(io.LimitReader(resp.Body, r.Config.MaxResponseBodySize)).Decode(&webFingerResponse); err != nil { + if err := json.Unmarshal(body, &webFingerResponse); err != nil { return nil, cachedActor, fmt.Errorf("failed to decode %s response: %w", finger, err) } @@ -432,20 +433,11 @@ func (r *Resolver) fetchActor(ctx context.Context, keys [2]httpsig.Key, host, pr req.Header.Add("Accept", `application/ld+json; profile="https://www.w3.org/ns/activitystreams"`) - resp, err := r.send(keys, req) + resp, body, cleanup, err := r.send(keys, req) if err != nil { return r.handleFetchFailure(ctx, profile, cachedActor, sinceLastUpdate, resp, err) } - defer resp.Body.Close() - - if resp.ContentLength > r.Config.MaxResponseBodySize { - return nil, cachedActor, fmt.Errorf("failed to fetch %s: response is too big", profile) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, r.Config.MaxResponseBodySize)) - if err != nil { - return nil, cachedActor, fmt.Errorf("failed to fetch %s: %w", profile, err) - } + defer cleanup() var actor ap.Actor if err := json.Unmarshal(body, &actor); err != nil { diff --git a/fed/resolve_test.go b/fed/resolve_test.go index bc99b9fe..d68d575c 100644 --- a/fed/resolve_test.go +++ b/fed/resolve_test.go @@ -2614,7 +2614,11 @@ func TestResolve_FederatedActorOldCacheBigWebFingerResponse(t *testing.T) { }, } - cfg.MaxResponseBodySize = 1 + resolver.Buffers = sync.Pool{ + New: func() any { + return make([]byte, 1) + }, + } actor, err = resolver.Resolve(context.Background(), key, "0.0.0.0", "dan", 0) assert.NoError(err) @@ -2874,7 +2878,11 @@ func TestResolve_FederatedActorOldCacheBigActor(t *testing.T) { }, } - cfg.MaxResponseBodySize = 419 + resolver.Buffers = sync.Pool{ + New: func() any { + return make([]byte, 419) + }, + } actor, err = resolver.Resolve(context.Background(), key, "0.0.0.0", "dan", 0) assert.NoError(err) diff --git a/fed/send.go b/fed/send.go index fea3968e..be5d149a 100644 --- a/fed/send.go +++ b/fed/send.go @@ -25,6 +25,7 @@ import ( "log/slog" "math/rand/v2" "net/http" + "sync" "time" "github.com/dimkr/tootik/ap" @@ -34,23 +35,24 @@ import ( ) type sender struct { - Domain string - Config *cfg.Config - client Client - DB *sql.DB + Domain string + Config *cfg.Config + client Client + DB *sql.DB + Buffers sync.Pool } var userAgent = "tootik/" + buildinfo.Version -func (s *sender) send(keys [2]httpsig.Key, req *http.Request) (*http.Response, error) { +func (s *sender) send(keys [2]httpsig.Key, req *http.Request) (int, []byte, func(), error) { urlString := req.URL.String() if req.URL.Scheme != "https" { - return nil, fmt.Errorf("invalid scheme in %s: %s", urlString, req.URL.Scheme) + return -1, nil, nil, fmt.Errorf("invalid scheme in %s: %s", urlString, req.URL.Scheme) } if req.URL.Host == "localhost" || req.URL.Host == "localhost.localdomain" || req.URL.Host == "127.0.0.1" || req.URL.Host == "::1" { - return nil, fmt.Errorf("invalid host in %s: %s", urlString, req.URL.Host) + return -1, nil, nil, fmt.Errorf("invalid host in %s: %s", urlString, req.URL.Host) } req.Header.Set("User-Agent", userAgent) @@ -63,7 +65,7 @@ func (s *sender) send(keys [2]httpsig.Key, req *http.Request) (*http.Response, e if err := s.DB.QueryRowContext(req.Context(), `select capabilities from servers where host = ?`, req.URL.Host).Scan(&capabilities); errors.Is(err, sql.ErrNoRows) { slog.Debug("Server capabilities are unknown", "url", urlString) } else if err != nil { - return nil, fmt.Errorf("failed to query server capabilities for %s: %w", req.URL.Host, err) + return -1, nil, nil, fmt.Errorf("failed to query server capabilities for %s: %w", req.URL.Host, err) } if capabilities&ap.RFC9421Ed25519Signatures == 0 && req.Method == http.MethodPost && rand.Float32() > s.Config.Ed25519Threshold { @@ -78,37 +80,36 @@ func (s *sender) send(keys [2]httpsig.Key, req *http.Request) (*http.Response, e slog.Debug("Signing request using RFC9421 with Ed25519", "method", req.Method, "url", urlString, "key", keys[1].ID) if err := httpsig.SignRFC9421(req, keys[1], time.Now(), time.Time{}, httpsig.RFC9421DigestSHA256, "ed25519", nil); err != nil { - return nil, fmt.Errorf("failed to sign request for %s: %w", urlString, err) + return -1, nil, nil, fmt.Errorf("failed to sign request for %s: %w", urlString, err) } } else if capabilities&ap.RFC9421RSASignatures > 0 { slog.Debug("Signing request using RFC9421 with RSA", "method", req.Method, "url", urlString, "key", keys[0].ID) if err := httpsig.SignRFC9421(req, keys[0], time.Now(), time.Time{}, httpsig.RFC9421DigestSHA256, "rsa-v1_5-sha256", nil); err != nil { - return nil, fmt.Errorf("failed to sign request for %s: %w", urlString, err) + return -1, nil, nil, fmt.Errorf("failed to sign request for %s: %w", urlString, err) } } else if err := httpsig.Sign(req, keys[0], time.Now()); err != nil { slog.Debug("Signing request using draft-cavage-http-signatures", "method", req.Method, "url", urlString, "key", keys[0].ID) - return nil, fmt.Errorf("failed to sign request for %s: %w", urlString, err) + return -1, nil, nil, fmt.Errorf("failed to sign request for %s: %w", urlString, err) } resp, err := s.client.Do(req) if err != nil { - return nil, fmt.Errorf("failed to send request to %s: %w", urlString, err) + return -1, nil, nil, fmt.Errorf("failed to send request to %s: %w", urlString, err) } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - defer resp.Body.Close() + buf := s.Buffers.Get().([]byte) + defer s.Buffers.Put(buf) - if resp.ContentLength > s.Config.MaxResponseBodySize { - return resp, fmt.Errorf("failed to send request to %s: %d", urlString, resp.StatusCode) + n, err := resp.Body.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + return resp.StatusCode, nil, nil, fmt.Errorf("failed to send request to %s: %d, %w", urlString, resp.StatusCode, err) } - body, err := io.ReadAll(io.LimitReader(resp.Body, s.Config.MaxResponseBodySize)) - if err != nil { - return resp, fmt.Errorf("failed to send request to %s: %d, %w", urlString, resp.StatusCode, err) - } - return resp, fmt.Errorf("failed to send request to %s: %d, %s", urlString, resp.StatusCode, string(body)) + return resp.StatusCode, nil, nil, fmt.Errorf("failed to send request to %s: %d, %s", urlString, resp.StatusCode, string(buf[:n])) } // other servers may ignore the signature if the request includes a valid integrity proof @@ -123,13 +124,21 @@ func (s *sender) send(keys [2]httpsig.Key, req *http.Request) (*http.Response, e } } - return resp, nil + buf := s.Buffers.Get().([]byte) + if n, err := resp.Body.Read(buf); err != nil && !errors.Is(err, io.EOF) { + s.Buffers.Put(buf) + return resp.StatusCode, nil, nil, fmt.Errorf("failed to send request to %s: %d, %w", urlString, resp.StatusCode, err) + } else { + return resp.StatusCode, buf[:n], func() { + s.Buffers.Put(buf) + }, nil + } } -func (s *sender) Get(ctx context.Context, keys [2]httpsig.Key, url string) (*http.Response, error) { +func (s *sender) Get(ctx context.Context, keys [2]httpsig.Key, url string) (int, []byte, func(), error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return nil, fmt.Errorf("failed to send request to %s: %w", url, err) + return -1, nil, nil, fmt.Errorf("failed to send request to %s: %w", url, err) } req.Header.Set("Accept", `application/ld+json; profile="https://www.w3.org/ns/activitystreams"`) diff --git a/front/finger/finger.go b/front/finger/finger.go index 56ed72ef..7dc660c6 100644 --- a/front/finger/finger.go +++ b/front/finger/finger.go @@ -37,38 +37,36 @@ import ( ) type Listener struct { - Domain string - Config *cfg.Config - DB *sql.DB - Addr string + Domain string + Config *cfg.Config + DB *sql.DB + Addr string + Buffers sync.Pool } -func (fl *Listener) handle(ctx context.Context, conn net.Conn) { - if err := conn.SetDeadline(time.Now().Add(fl.Config.GuppyRequestTimeout)); err != nil { - slog.Warn("Failed to set deadline", "error", err) - return - } +func (fl *Listener) readRequest(conn net.Conn) string { + req := fl.Buffers.Get().([]byte) + defer fl.Buffers.Put(req) - req := make([]byte, 34) total := 0 for { n, err := conn.Read(req[total:]) if err != nil && total == 0 && errors.Is(err, io.EOF) { slog.Debug("Failed to receive request", "error", err) - return + return "" } else if err != nil { slog.Warn("Failed to receive request", "error", err) - return + return "" } if n <= 0 { slog.Warn("Failed to receive request") - return + return "" } total += n if total == cap(req) { slog.Warn("Request is too big") - return + return "" } if total >= 2 && req[total-2] == '\r' && req[total-1] == '\n' { @@ -76,14 +74,23 @@ func (fl *Listener) handle(ctx context.Context, conn net.Conn) { } } - user := string(req[:total-2]) - log := slog.With(slog.String("user", user)) + return string(req[:total-2]) +} +func (fl *Listener) handle(ctx context.Context, conn net.Conn) { + if err := conn.SetDeadline(time.Now().Add(fl.Config.GuppyRequestTimeout)); err != nil { + slog.Warn("Failed to set deadline", "error", err) + return + } + + user := fl.readRequest(conn) if user == "" { - log.Warn("Invalid username specified") + slog.Warn("Invalid username specified") return } + log := slog.With(slog.String("user", user)) + sep := strings.IndexByte(user, '@') if sep > 0 && user[sep+1:] != fl.Domain { log.Warn("Invalid domain specified") diff --git a/front/gemini/gemini.go b/front/gemini/gemini.go index 8b8f38e5..d9ec4eaf 100644 --- a/front/gemini/gemini.go +++ b/front/gemini/gemini.go @@ -47,6 +47,7 @@ type Listener struct { Addr string CertPath string KeyPath string + Buffers sync.Pool } func (gl *Listener) getUser(ctx context.Context, tlsConn *tls.Conn) (*ap.Actor, [2]httpsig.Key, error) { @@ -94,44 +95,29 @@ func (gl *Listener) getUser(ctx context.Context, tlsConn *tls.Conn) (*ap.Actor, }, nil } -// Handle handles a Gemini request. -func (gl *Listener) Handle(ctx context.Context, conn net.Conn) { - if err := conn.SetDeadline(time.Now().Add(gl.Config.GeminiRequestTimeout)); err != nil { - slog.Warn("Failed to set deadline", "error", err) - return - } - - tlsConn, ok := conn.(*tls.Conn) - if !ok { - slog.Warn("Invalid connection") - return - } - - if err := tlsConn.HandshakeContext(ctx); err != nil { - slog.Warn("Handshake failed", "error", err) - return - } +func (gl *Listener) readRequest(ctx context.Context, conn net.Conn) *front.Request { + req := gl.Buffers.Get().([]byte) + defer gl.Buffers.Put(req) - req := make([]byte, 1024+2) total := 0 for { n, err := conn.Read(req[total : total+1]) if err != nil && total == 0 && errors.Is(err, io.EOF) { slog.Debug("Failed to receive request", "error", err) - return + return nil } else if err != nil { slog.Warn("Failed to receive request", "error", err) - return + return nil } if n <= 0 { slog.Warn("Failed to receive request") - return + return nil } total += n if total == cap(req) { slog.Warn("Request is too big") - return + return nil } if total > 2 && req[total-2] == '\r' && req[total-1] == '\n' { @@ -139,7 +125,7 @@ func (gl *Listener) Handle(ctx context.Context, conn net.Conn) { } } - r := front.Request{ + r := &front.Request{ Context: ctx, Body: conn, } @@ -148,12 +134,39 @@ func (gl *Listener) Handle(ctx context.Context, conn net.Conn) { r.URL, err = url.Parse(string(req[:total-2])) if err != nil { slog.Warn("Failed to parse request", "request", string(req[:total-2]), "error", err) + return nil + } + + return r +} + +// Handle handles a Gemini request. +func (gl *Listener) Handle(ctx context.Context, conn net.Conn) { + if err := conn.SetDeadline(time.Now().Add(gl.Config.GeminiRequestTimeout)); err != nil { + slog.Warn("Failed to set deadline", "error", err) + return + } + + tlsConn, ok := conn.(*tls.Conn) + if !ok { + slog.Warn("Invalid connection") + return + } + + if err := tlsConn.HandshakeContext(ctx); err != nil { + slog.Warn("Handshake failed", "error", err) + return + } + + r := gl.readRequest(ctx, conn) + if r == nil { return } w := gmi.Wrap(conn) defer w.Flush() + var err error r.User, r.Keys, err = gl.getUser(ctx, tlsConn) if err != nil && errors.Is(err, front.ErrNotRegistered) && r.URL.Path == "/users" { slog.Info("Redirecting new user") @@ -180,7 +193,7 @@ func (gl *Listener) Handle(ctx context.Context, conn net.Conn) { r.Log = slog.With(slog.Group("request", "path", r.URL.Path, "user", r.User.PreferredUsername)) } - gl.Handler.Handle(&r, w) + gl.Handler.Handle(r, w) } // ListenAndServe handles Gemini requests. diff --git a/front/gopher/gopher.go b/front/gopher/gopher.go index 4fe9b1bf..e7afdcd1 100644 --- a/front/gopher/gopher.go +++ b/front/gopher/gopher.go @@ -37,34 +37,32 @@ type Listener struct { Config *cfg.Config Handler front.Handler Addr string + Buffers sync.Pool } -func (gl *Listener) handle(ctx context.Context, conn net.Conn) { - if err := conn.SetDeadline(time.Now().Add(gl.Config.GopherRequestTimeout)); err != nil { - slog.Warn("Failed to set deadline", "error", err) - return - } +func (gl *Listener) readRequest(ctx context.Context, conn net.Conn) *front.Request { + req := gl.Buffers.Get().([]byte) + defer gl.Buffers.Put(req) - req := make([]byte, 256) total := 0 for { n, err := conn.Read(req[total:]) if err != nil && total == 0 && errors.Is(err, io.EOF) { slog.Debug("Failed to receive request", "error", err) - return + return nil } else if err != nil { slog.Warn("Failed to receive request", "error", err) - return + return nil } if n <= 0 { slog.Warn("Failed to receive request") - return + return nil } total += n if total == cap(req) { slog.Warn("Request is too big") - return + return nil } if total >= 2 && req[total-2] == '\r' && req[total-1] == '\n' { @@ -77,7 +75,7 @@ func (gl *Listener) handle(ctx context.Context, conn net.Conn) { path = "/" } - r := front.Request{ + r := &front.Request{ Context: ctx, Body: conn, } @@ -86,15 +84,26 @@ func (gl *Listener) handle(ctx context.Context, conn net.Conn) { r.URL, err = url.Parse(path) if err != nil { slog.Warn("Failed to parse request", "path", path, "error", err) - return + return nil } r.Log = slog.With(slog.Group("request", "path", r.URL.Path)) - w := gmap.Wrap(conn, gl.Domain, gl.Config) - defer w.Flush() + return r +} + +func (gl *Listener) handle(ctx context.Context, conn net.Conn) { + if err := conn.SetDeadline(time.Now().Add(gl.Config.GopherRequestTimeout)); err != nil { + slog.Warn("Failed to set deadline", "error", err) + return + } + + if r := gl.readRequest(ctx, conn); r != nil { + w := gmap.Wrap(conn, gl.Domain, gl.Config) + defer w.Flush() - gl.Handler.Handle(&r, w) + gl.Handler.Handle(r, w) + } } // ListenAndServe handles Gopher requests. diff --git a/test/register_test.go b/test/register_test.go index 0ee8f349..41ad5ff8 100644 --- a/test/register_test.go +++ b/test/register_test.go @@ -193,6 +193,11 @@ func TestRegister_RedirectNoCertificate(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -273,6 +278,11 @@ func TestRegister_Redirect(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -350,6 +360,11 @@ func TestRegister_NoCertificate(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -430,6 +445,11 @@ func TestRegister_HappyFlow(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -510,6 +530,11 @@ func TestRegister_HappyFlowRegistrationClosed(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -594,6 +619,11 @@ func TestRegister_AlreadyRegistered(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -680,6 +710,11 @@ func TestRegister_Twice(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -776,6 +811,11 @@ func TestRegister_Throttling(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -872,6 +912,11 @@ func TestRegister_Throttling30Minutes(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -971,6 +1016,11 @@ func TestRegister_Throttling1Hour(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -1086,6 +1136,11 @@ func TestRegister_TwoCertificates(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter) @@ -1168,6 +1223,11 @@ func TestRegister_ForbiddenUserName(t *testing.T) { Config: &cfg, Handler: handler, DB: db, + Buffers: sync.Pool{ + New: func() any { + return make([]byte, 1024) + }, + }, } l.Handle(context.Background(), tlsWriter)