Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ap/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package ap

import (
"context"
"net/http"

"github.com/dimkr/tootik/httpsig"
)
Expand All @@ -40,5 +39,5 @@ const (
type Resolver interface {
ResolveID(ctx context.Context, keys [2]httpsig.Key, id string, flags ResolverFlag) (*Actor, error)
Resolve(ctx context.Context, keys [2]httpsig.Key, host, name string, flags ResolverFlag) (*Actor, error)
Get(ctx context.Context, keys [2]httpsig.Key, url string) (*http.Response, error)
Get(ctx context.Context, keys [2]httpsig.Key, url string) (int, []byte, func(), error)
}
10 changes: 10 additions & 0 deletions cluster/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ func NewServer(ctx context.Context, t *testing.T, domain string, client fed.Clie
DB: db,
ActorKeys: nobodyKeys,
Resolver: resolver,
Buffers: sync.Pool{
New: func() any {
return make([]byte, cfg.MaxRequestBodySize)
},
},
}).NewHandler()
if err != nil {
t.Fatalf("Failed to run create the federation handler: %v", err)
Expand Down Expand Up @@ -220,6 +225,11 @@ func NewServer(ctx context.Context, t *testing.T, domain string, client fed.Clie
Config: &cfg,
DB: db,
Handler: handler,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 1024)
},
},
},
Backend: backend,
Inbox: localInbox,
Expand Down
20 changes: 20 additions & 0 deletions cmd/tootik/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ func main() {
Key: *key,
Plain: *plain,
BlockList: blockList,
Buffers: sync.Pool{
New: func() any {
return make([]byte, cfg.MaxRequestBodySize)
},
},
},
},
{
Expand All @@ -354,6 +359,11 @@ func main() {
Addr: *gemAddr,
CertPath: *gemCert,
KeyPath: *gemKey,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 1024+2)
},
},
},
},
{
Expand All @@ -363,6 +373,11 @@ func main() {
Config: &cfg,
Handler: handler,
Addr: *gopherAddr,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 256)
},
},
},
},
{
Expand All @@ -372,6 +387,11 @@ func main() {
Config: &cfg,
DB: db,
Addr: *fingerAddr,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 34)
},
},
},
},
{
Expand Down
4 changes: 2 additions & 2 deletions fed/deliver.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ func (q *Queue) deliverWithTimeout(parent context.Context, task deliveryTask) er

req := task.Request.WithContext(ctx)

resp, err := q.Resolver.send(task.Keys, req)
_, _, cleanup, err := q.Resolver.send(task.Keys, req)
if err == nil {
resp.Body.Close()
cleanup()
}
return err
}
Expand Down
11 changes: 3 additions & 8 deletions fed/followers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
Expand Down Expand Up @@ -248,20 +247,16 @@ func (d *followersDigest) Sync(ctx context.Context, domain string, cfg *cfg.Conf

slog.Info("Synchronizing followers", "followed", d.Followed)

resp, err := resolver.Get(ctx, keys, d.URL)
_, body, cleanup, err := resolver.Get(ctx, keys, d.URL)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.ContentLength > cfg.MaxResponseBodySize {
return errors.New("response is too big")
}
defer cleanup()

var remote struct {
OrderedItems ap.Audience `json:"orderedItems"`
}
if err := json.NewDecoder(io.LimitReader(resp.Body, cfg.MaxResponseBodySize)).Decode(&remote); err != nil {
if err := json.Unmarshal(body, &remote); err != nil {
return err
}

Expand Down
21 changes: 8 additions & 13 deletions fed/inbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,23 +219,14 @@ func (l *Listener) validateActivity(activity *ap.Activity, origin string, depth
}

func (l *Listener) fetchObject(ctx context.Context, id string, keys [2]httpsig.Key) (bool, []byte, error) {
resp, err := l.Resolver.Get(ctx, keys, id)
statusCode, body, cleanup, err := l.Resolver.Get(ctx, keys, id)
if err != nil {
if resp != nil && (resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusGone) {
if statusCode == http.StatusNotFound || statusCode == http.StatusGone {
return false, nil, err
}
return true, nil, err
}
defer resp.Body.Close()

if resp.ContentLength > l.Config.MaxRequestBodySize {
return true, nil, fmt.Errorf("object is too big: %d", resp.ContentLength)
}

body, err := io.ReadAll(io.LimitReader(resp.Body, l.Config.MaxRequestBodySize))
if err != nil {
return true, nil, err
}
defer cleanup()

if !ap.IsPortable(id) {
return true, body, nil
Expand Down Expand Up @@ -320,11 +311,15 @@ func (l *Listener) doHandleInbox(w http.ResponseWriter, r *http.Request, keys [2
return
}

rawActivity, err := io.ReadAll(io.LimitReader(r.Body, l.Config.MaxRequestBodySize))
buf := l.Buffers.Get().([]byte)
defer l.Buffers.Put(buf)

n, err := r.Body.Read(buf)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
rawActivity := buf[:n]

var activity ap.Activity
if err := json.Unmarshal(rawActivity, &activity); err != nil {
Expand Down
1 change: 1 addition & 0 deletions fed/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Listener struct {
Key string
Plain bool
BlockList *BlockList
Buffers sync.Pool
}

const certReloadDelay = time.Second * 5
Expand Down
36 changes: 14 additions & 22 deletions fed/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import (
"errors"
"fmt"
"hash/crc32"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/dimkr/tootik/ap"
Expand Down Expand Up @@ -69,6 +69,11 @@ func NewResolver(blockedDomains *BlockList, domain string, cfg *cfg.Config, clie
Config: cfg,
client: client,
DB: db,
Buffers: sync.Pool{
New: func() any {
return make([]byte, cfg.MaxResponseBodySize)
},
},
},
BlockedDomains: blockedDomains,
db: db,
Expand Down Expand Up @@ -174,8 +179,8 @@ func deleteActor(ctx context.Context, db *sql.DB, id string) {
}
}

func (r *Resolver) handleFetchFailure(ctx context.Context, fetched string, cachedActor *ap.Actor, sinceLastUpdate time.Duration, resp *http.Response, err error) (*ap.Actor, *ap.Actor, error) {
if resp != nil && (resp.StatusCode == http.StatusGone || resp.StatusCode == http.StatusNotFound) {
func (r *Resolver) handleFetchFailure(ctx context.Context, fetched string, cachedActor *ap.Actor, sinceLastUpdate time.Duration, statusCode int, err error) (*ap.Actor, *ap.Actor, error) {
if statusCode == http.StatusGone || statusCode == http.StatusNotFound {
if cachedActor != nil {
slog.Warn("Actor is gone, deleting associated objects", "id", cachedActor.ID)
deleteActor(ctx, r.db, cachedActor.ID)
Expand Down Expand Up @@ -289,18 +294,14 @@ func (r *Resolver) tryResolve(ctx context.Context, keys [2]httpsig.Key, host, na
}
req.Header.Add("Accept", "application/json")

resp, err := r.send(keys, req)
statusCode, body, cleanup, err := r.send(keys, req)
if err != nil {
return r.handleFetchFailure(ctx, finger, cachedActor, sinceLastUpdate, resp, err)
}
defer resp.Body.Close()

if resp.ContentLength > r.Config.MaxResponseBodySize {
return nil, cachedActor, fmt.Errorf("failed to decode %s response: response is too big", finger)
return r.handleFetchFailure(ctx, finger, cachedActor, sinceLastUpdate, statusCode, err)
}
defer cleanup()

var webFingerResponse webFingerResponse
if err := json.NewDecoder(io.LimitReader(resp.Body, r.Config.MaxResponseBodySize)).Decode(&webFingerResponse); err != nil {
if err := json.Unmarshal(body, &webFingerResponse); err != nil {
return nil, cachedActor, fmt.Errorf("failed to decode %s response: %w", finger, err)
}

Expand Down Expand Up @@ -432,20 +433,11 @@ func (r *Resolver) fetchActor(ctx context.Context, keys [2]httpsig.Key, host, pr

req.Header.Add("Accept", `application/ld+json; profile="https://www.w3.org/ns/activitystreams"`)

resp, err := r.send(keys, req)
resp, body, cleanup, err := r.send(keys, req)
if err != nil {
return r.handleFetchFailure(ctx, profile, cachedActor, sinceLastUpdate, resp, err)
}
defer resp.Body.Close()

if resp.ContentLength > r.Config.MaxResponseBodySize {
return nil, cachedActor, fmt.Errorf("failed to fetch %s: response is too big", profile)
}

body, err := io.ReadAll(io.LimitReader(resp.Body, r.Config.MaxResponseBodySize))
if err != nil {
return nil, cachedActor, fmt.Errorf("failed to fetch %s: %w", profile, err)
}
defer cleanup()

var actor ap.Actor
if err := json.Unmarshal(body, &actor); err != nil {
Expand Down
12 changes: 10 additions & 2 deletions fed/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2614,7 +2614,11 @@ func TestResolve_FederatedActorOldCacheBigWebFingerResponse(t *testing.T) {
},
}

cfg.MaxResponseBodySize = 1
resolver.Buffers = sync.Pool{
New: func() any {
return make([]byte, 1)
},
}

actor, err = resolver.Resolve(context.Background(), key, "0.0.0.0", "dan", 0)
assert.NoError(err)
Expand Down Expand Up @@ -2874,7 +2878,11 @@ func TestResolve_FederatedActorOldCacheBigActor(t *testing.T) {
},
}

cfg.MaxResponseBodySize = 419
resolver.Buffers = sync.Pool{
New: func() any {
return make([]byte, 419)
},
}

actor, err = resolver.Resolve(context.Background(), key, "0.0.0.0", "dan", 0)
assert.NoError(err)
Expand Down
Loading
Loading