Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ap/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package ap

import (
"context"
"net/http"

"github.com/dimkr/tootik/httpsig"
)
Expand All @@ -40,5 +39,5 @@ const (
type Resolver interface {
ResolveID(ctx context.Context, keys [2]httpsig.Key, id string, flags ResolverFlag) (*Actor, error)
Resolve(ctx context.Context, keys [2]httpsig.Key, host, name string, flags ResolverFlag) (*Actor, error)
Get(ctx context.Context, keys [2]httpsig.Key, url string) (*http.Response, error)
Get(ctx context.Context, keys [2]httpsig.Key, url string) (int, []byte, func(), error)
}
10 changes: 10 additions & 0 deletions cluster/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ func NewServer(ctx context.Context, t *testing.T, domain string, client fed.Clie
DB: db,
ActorKeys: nobodyKeys,
Resolver: resolver,
Buffers: sync.Pool{
New: func() any {
return make([]byte, cfg.MaxRequestBodySize)
},
},
}).NewHandler()
if err != nil {
t.Fatalf("Failed to run create the federation handler: %v", err)
Expand Down Expand Up @@ -220,6 +225,11 @@ func NewServer(ctx context.Context, t *testing.T, domain string, client fed.Clie
Config: &cfg,
DB: db,
Handler: handler,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 1024)
},
},
},
Backend: backend,
Inbox: localInbox,
Expand Down
20 changes: 20 additions & 0 deletions cmd/tootik/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ func main() {
Key: *key,
Plain: *plain,
BlockList: blockList,
Buffers: sync.Pool{
New: func() any {
return make([]byte, cfg.MaxRequestBodySize)
},
},
},
},
{
Expand All @@ -354,6 +359,11 @@ func main() {
Addr: *gemAddr,
CertPath: *gemCert,
KeyPath: *gemKey,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 1024+2)
},
},
},
},
{
Expand All @@ -363,6 +373,11 @@ func main() {
Config: &cfg,
Handler: handler,
Addr: *gopherAddr,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 256)
},
},
},
},
{
Expand All @@ -372,6 +387,11 @@ func main() {
Config: &cfg,
DB: db,
Addr: *fingerAddr,
Buffers: sync.Pool{
New: func() any {
return make([]byte, 34)
},
},
},
},
{
Expand Down
4 changes: 2 additions & 2 deletions fed/deliver.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ func (q *Queue) deliverWithTimeout(parent context.Context, task *deliveryTask) e

req := task.Request.WithContext(ctx)

resp, err := q.Resolver.send(task.Keys, req)
_, _, cleanup, err := q.Resolver.send(task.Keys, req)
if err == nil {
resp.Body.Close()
cleanup()
}
return err
}
Expand Down
11 changes: 3 additions & 8 deletions fed/followers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
Expand Down Expand Up @@ -248,20 +247,16 @@ func (d *followersDigest) Sync(ctx context.Context, domain string, cfg *cfg.Conf

slog.Info("Synchronizing followers", "followed", d.Followed)

resp, err := resolver.Get(ctx, keys, d.URL)
_, body, cleanup, err := resolver.Get(ctx, keys, d.URL)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.ContentLength > cfg.MaxResponseBodySize {
return errors.New("response is too big")
}
defer cleanup()

var remote struct {
OrderedItems ap.Audience `json:"orderedItems"`
}
if err := json.NewDecoder(io.LimitReader(resp.Body, cfg.MaxResponseBodySize)).Decode(&remote); err != nil {
if err := json.Unmarshal(body, &remote); err != nil {
return err
}

Expand Down
23 changes: 9 additions & 14 deletions fed/inbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,23 +219,14 @@ func (l *Listener) validateActivity(activity *ap.Activity, origin string, depth
}

func (l *Listener) fetchObject(ctx context.Context, id string, keys [2]httpsig.Key) (bool, []byte, error) {
resp, err := l.Resolver.Get(ctx, keys, id)
statusCode, body, cleanup, err := l.Resolver.Get(ctx, keys, id)
if err != nil {
if resp != nil && (resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusGone) {
if statusCode == http.StatusNotFound || statusCode == http.StatusGone {
return false, nil, err
}
return true, nil, err
}
defer resp.Body.Close()

if resp.ContentLength > l.Config.MaxRequestBodySize {
return true, nil, fmt.Errorf("object is too big: %d", resp.ContentLength)
}

body, err := io.ReadAll(io.LimitReader(resp.Body, l.Config.MaxRequestBodySize))
if err != nil {
return true, nil, err
}
defer cleanup()

if !ap.IsPortable(id) {
return true, body, nil
Expand Down Expand Up @@ -320,11 +311,15 @@ func (l *Listener) doHandleInbox(w http.ResponseWriter, r *http.Request, keys [2
return
}

rawActivity, err := io.ReadAll(io.LimitReader(r.Body, l.Config.MaxRequestBodySize))
if err != nil {
buf := l.Buffers.Get().([]byte)
defer l.Buffers.Put(buf)

n, err := r.Body.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
w.WriteHeader(http.StatusInternalServerError)
return
}
rawActivity := buf[:n]

var activity ap.Activity
if err := json.Unmarshal(rawActivity, &activity); err != nil {
Expand Down
1 change: 1 addition & 0 deletions fed/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Listener struct {
Key string
Plain bool
BlockList *BlockList
Buffers sync.Pool
}

const certReloadDelay = time.Second * 5
Expand Down
36 changes: 14 additions & 22 deletions fed/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import (
"errors"
"fmt"
"hash/crc32"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/dimkr/tootik/ap"
Expand Down Expand Up @@ -69,6 +69,11 @@ func NewResolver(blockedDomains *BlockList, domain string, cfg *cfg.Config, clie
Config: cfg,
client: client,
DB: db,
Buffers: sync.Pool{
New: func() any {
return make([]byte, cfg.MaxResponseBodySize)
},
},
},
BlockedDomains: blockedDomains,
db: db,
Expand Down Expand Up @@ -174,8 +179,8 @@ func deleteActor(ctx context.Context, db *sql.DB, id string) {
}
}

func (r *Resolver) handleFetchFailure(ctx context.Context, fetched string, cachedActor *ap.Actor, sinceLastUpdate time.Duration, resp *http.Response, err error) (*ap.Actor, *ap.Actor, error) {
if resp != nil && (resp.StatusCode == http.StatusGone || resp.StatusCode == http.StatusNotFound) {
func (r *Resolver) handleFetchFailure(ctx context.Context, fetched string, cachedActor *ap.Actor, sinceLastUpdate time.Duration, statusCode int, err error) (*ap.Actor, *ap.Actor, error) {
if statusCode == http.StatusGone || statusCode == http.StatusNotFound {
if cachedActor != nil {
slog.Warn("Actor is gone, deleting associated objects", "id", cachedActor.ID)
deleteActor(ctx, r.db, cachedActor.ID)
Expand Down Expand Up @@ -289,18 +294,14 @@ func (r *Resolver) tryResolve(ctx context.Context, keys [2]httpsig.Key, host, na
}
req.Header.Add("Accept", "application/json")

resp, err := r.send(keys, req)
statusCode, body, cleanup, err := r.send(keys, req)
if err != nil {
return r.handleFetchFailure(ctx, finger, cachedActor, sinceLastUpdate, resp, err)
}
defer resp.Body.Close()

if resp.ContentLength > r.Config.MaxResponseBodySize {
return nil, cachedActor, fmt.Errorf("failed to decode %s response: response is too big", finger)
return r.handleFetchFailure(ctx, finger, cachedActor, sinceLastUpdate, statusCode, err)
}
defer cleanup()

var webFingerResponse webFingerResponse
if err := json.NewDecoder(io.LimitReader(resp.Body, r.Config.MaxResponseBodySize)).Decode(&webFingerResponse); err != nil {
if err := json.Unmarshal(body, &webFingerResponse); err != nil {
return nil, cachedActor, fmt.Errorf("failed to decode %s response: %w", finger, err)
}

Expand Down Expand Up @@ -432,20 +433,11 @@ func (r *Resolver) fetchActor(ctx context.Context, keys [2]httpsig.Key, host, pr

req.Header.Add("Accept", `application/ld+json; profile="https://www.w3.org/ns/activitystreams"`)

resp, err := r.send(keys, req)
resp, body, cleanup, err := r.send(keys, req)
if err != nil {
return r.handleFetchFailure(ctx, profile, cachedActor, sinceLastUpdate, resp, err)
}
defer resp.Body.Close()

if resp.ContentLength > r.Config.MaxResponseBodySize {
return nil, cachedActor, fmt.Errorf("failed to fetch %s: response is too big", profile)
}

body, err := io.ReadAll(io.LimitReader(resp.Body, r.Config.MaxResponseBodySize))
if err != nil {
return nil, cachedActor, fmt.Errorf("failed to fetch %s: %w", profile, err)
}
defer cleanup()

var actor ap.Actor
if err := json.Unmarshal(body, &actor); err != nil {
Expand Down
12 changes: 10 additions & 2 deletions fed/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2614,7 +2614,11 @@ func TestResolve_FederatedActorOldCacheBigWebFingerResponse(t *testing.T) {
},
}

cfg.MaxResponseBodySize = 1
resolver.Buffers = sync.Pool{
New: func() any {
return make([]byte, 1)
},
}

actor, err = resolver.Resolve(context.Background(), key, "0.0.0.0", "dan", 0)
assert.NoError(err)
Expand Down Expand Up @@ -2874,7 +2878,11 @@ func TestResolve_FederatedActorOldCacheBigActor(t *testing.T) {
},
}

cfg.MaxResponseBodySize = 419
resolver.Buffers = sync.Pool{
New: func() any {
return make([]byte, 419)
},
}

actor, err = resolver.Resolve(context.Background(), key, "0.0.0.0", "dan", 0)
assert.NoError(err)
Expand Down
Loading
Loading