diff --git a/mcp/event.go b/mcp/event.go index 5c322c4a..57ce4ab3 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -66,10 +66,12 @@ func writeEvent(w io.Writer, evt Event) (int, error) { // // TODO(rfindley): consider a different API here that makes failure modes more // apparent. -func scanEvents(r io.Reader) iter.Seq2[Event, error] { +func scanEvents(r io.Reader, maxLineSize int) iter.Seq2[Event, error] { scanner := bufio.NewScanner(r) - const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size - scanner.Buffer(nil, maxTokenSize) + if maxLineSize <= 0 { + maxLineSize = 1 * 1024 * 1024 // zero or negative selects the 1 MiB default + } + scanner.Buffer(nil, maxLineSize) // TODO: investigate proper behavior when events are out of order, or have // non-standard names. @@ -139,7 +141,7 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { } if err := scanner.Err(); err != nil { if errors.Is(err, bufio.ErrTooLong) { - err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + err = fmt.Errorf("event exceeded max line length of %d", maxLineSize) } if !yield(Event{}, err) { return diff --git a/mcp/event_test.go b/mcp/event_test.go index dacf30e8..cbb0a46e 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -15,10 +15,11 @@ import ( func TestScanEvents(t *testing.T) { tests := []struct { - name string - input string - want []Event - wantErr string + name string + input string + want []Event + wantErr string + maxLineSize int }{ { name: "simple event", @@ -54,6 +55,12 @@ func TestScanEvents(t *testing.T) { input: "invalid line\n\n", wantErr: "malformed line", }, + { + name: "event exceeds buffer size", + input: "data: " + strings.Repeat("x", 200) + "\n\n", + maxLineSize: 100, + wantErr: "event exceeded max line length of 100", + }, } for _, tt := range tests { @@ -61,7 +68,7 @@ func TestScanEvents(t *testing.T) { r := strings.NewReader(tt.input) var got []Event var err error - for e, err2 := range scanEvents(r) { + for e, err2 := range scanEvents(r, tt.maxLineSize) { if 
err2 != nil { err = err2 break diff --git a/mcp/sse.go b/mcp/sse.go index 7f644918..02da7bfa 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -329,6 +329,9 @@ type SSEClientTransport struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client + + // MaxLineSize is the maximum size, in bytes, of a single line of an SSE event read from the stream; a longer line is reported as an error. If zero, it defaults to 1 MiB. + MaxLineSize int } // Connect connects through the client endpoint. @@ -353,7 +356,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { msgEndpoint, err := func() (*url.URL, error) { var evt Event - for evt, err = range scanEvents(resp.Body) { + for evt, err = range scanEvents(resp.Body, c.MaxLineSize) { break } if err != nil { @@ -374,7 +377,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { s := &sseClientConn{ client: httpClient, msgEndpoint: msgEndpoint, - incoming: make(chan []byte, 100), + incoming: make(chan sseMessage, 100), body: resp.Body, done: make(chan struct{}), } @@ -382,12 +385,16 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { go func() { defer s.Close() // close the transport when the GET exits - for evt, err := range scanEvents(resp.Body) { + for evt, err := range scanEvents(resp.Body, c.MaxLineSize) { if err != nil { + select { + case s.incoming <- sseMessage{err: err}: + case <-s.done: + } return } select { - case s.incoming <- evt.Data: + case s.incoming <- sseMessage{data: evt.Data}: case <-s.done: return } @@ -397,15 +404,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return s, nil } +// sseMessage is one item queued by the SSE reader goroutine for Read: err is set when scanning the stream failed, otherwise data holds the event's data. +type sseMessage struct { + data []byte + err error +} + // An sseClientConn is a logical jsonrpc2 connection that implements the client // half of the SSE protocol: // - Writes are POSTS to the session endpoint. 
// - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - client *http.Client // HTTP client to use for requests - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan sseMessage // queue of incoming messages or errors mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -430,12 +443,15 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { case <-c.done: return nil, io.EOF - case data := <-c.incoming: + case m := <-c.incoming: + if m.err != nil { // NOTE(review): if Close closes c.done (not shown here), select may pick the done case above while an error is still queued, returning io.EOF instead of the error — confirm + return nil, m.err + } // TODO(rfindley): do we really need to check this? We receive from c.done above. if c.isDone() { return nil, io.EOF } - msg, err := jsonrpc2.DecodeMessage(data) + msg, err := jsonrpc2.DecodeMessage(m.data) if err != nil { return nil, err } diff --git a/mcp/streamable.go b/mcp/streamable.go index b4b2fa31..09f9c1f1 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1388,6 +1388,8 @@ type StreamableClientTransport struct { // MaxRetries is the maximum number of times to attempt a reconnect before giving up. // It defaults to 5. To disable retries, use a negative number. MaxRetries int + // MaxLineSize is the maximum size, in bytes, of a single line of an SSE event read from the stream; a longer line is reported as an error. If zero, it defaults to 1 MiB. + MaxLineSize int // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation @@ -1453,16 +1455,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // middleware), yet only cancel the standalone stream when the connection is closed. 
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - maxRetries: maxRetries, - strict: t.strict, - logger: ensureLogger(t.logger), // must be non-nil for safe logging - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: ensureLogger(t.logger), // must be non-nil for safe logging + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), + maxLineSize: t.MaxLineSize, } return conn, nil } @@ -1497,6 +1500,7 @@ type streamableClientConn struct { mu sync.Mutex initializedResult *InitializeResult sessionID string + maxLineSize int // max SSE line size passed to scanEvents; copied from StreamableClientTransport.MaxLineSize in Connect } // errSessionMissing distinguishes if the session is known to not be present on @@ -1854,11 +1858,14 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary io.Copy(io.Discard, resp.Body) resp.Body.Close() }() - for evt, err := range scanEvents(resp.Body) { + for evt, err := range scanEvents(resp.Body, c.maxLineSize) { if err != nil { if ctx.Err() != nil { return "", 0, true // don't reconnect: client cancelled } + + // bufio.Scanner reports a clean EOF as a nil error, so every error seen here is a real read or parse failure; report it via c.fail, wrapped with %w so the cause can still be unwrapped. NOTE(review): this also runs for transient network errors, which previously only broke out of the loop — confirm that calling c.fail before any reconnect is intended. + c.fail(fmt.Errorf("%s: failed to process stream: %w", requestSummary, err)) break } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index b1c3f074..2230ae6d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1425,7 +1425,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, var respBody []byte if strings.HasPrefix(contentType, "text/event-stream") { r := readerInto{resp.Body, new(bytes.Buffer)} - for evt, err := range scanEvents(r) { + for evt, err := range scanEvents(r, 0) { if err != nil { return newSessionID, 
resp.StatusCode, nil, fmt.Errorf("reading events: %v", err) } @@ -2143,7 +2143,7 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}} var events []Event // Scan all events - for evt, err := range scanEvents(reader) { + for evt, err := range scanEvents(reader, 0) { if err != nil { if err != io.EOF { t.Fatalf("scanEvents error: %v", err) diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 515b8c19..6f9363be 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -7,6 +7,8 @@ package mcp import ( "context" "io" + "net/http" + "net/http/httptest" "strings" "testing" @@ -124,3 +126,111 @@ func TestIOConnRead(t *testing.T) { }) } } + +// TestScanEventsBufferError checks that a tool result larger than the client +// transport's MaxLineSize fails the call, and that a smaller one succeeds. +func TestScanEventsBufferError(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + clientTransport func(url string) Transport + serverHandler func(server *Server) http.Handler + responseLength int + expectedContainsError string + }{ + { + name: "sse-large-output", + clientTransport: func(url string) Transport { + return &SSEClientTransport{ + Endpoint: url, + MaxLineSize: 1024, + } + }, + serverHandler: func(server *Server) http.Handler { + return NewSSEHandler(func(req *http.Request) *Server { return server }, nil) + }, + responseLength: 10000, + expectedContainsError: "exceeded max line length", + }, + { + name: "streamable-large-output", + clientTransport: func(url string) Transport { + return &StreamableClientTransport{ + Endpoint: url, + MaxLineSize: 1024, + } + }, + serverHandler: func(server *Server) http.Handler { + return NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + }, + responseLength: 10000, + expectedContainsError: "exceeded max line length", + }, + { + name: "sse-small-output", + clientTransport: func(url string) Transport { + return &SSEClientTransport{ + Endpoint: url, + MaxLineSize: 1024, + } + }, + serverHandler: func(server *Server) http.Handler { + return NewSSEHandler(func(req *http.Request) *Server { return server }, nil) + 
}, + responseLength: 512, + }, + { + name: "streamable-small-output", + clientTransport: func(url string) Transport { + return &StreamableClientTransport{ + Endpoint: url, + MaxLineSize: 1024, + } + }, + serverHandler: func(server *Server) http.Handler { + return NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + }, + responseLength: 512, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + largeResponse := strings.Repeat("x", tt.responseLength) + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "largeTool", Description: "returns large response"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: largeResponse}}}, nil, nil + }) + + httpHandler := tt.serverHandler(server) + httpServer := httptest.NewServer(mustNotPanic(t, httpHandler)) + defer httpServer.Close() + + client := NewClient(testImpl, nil) + clientTransport := tt.clientTransport(httpServer.URL) + session, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + _, err = session.CallTool(ctx, &CallToolParams{ + Name: "largeTool", + Arguments: map[string]any{}, + }) + // Cases without an expected error must succeed. + if tt.expectedContainsError == "" { + if err != nil { + t.Fatalf("client.CallTool() unexpectedly failed: %v", err) + } + return + } + // Otherwise the call must fail with the expected substring. + if err == nil { + t.Fatal("expected error due to small buffer, got nil") + } + if !strings.Contains(err.Error(), tt.expectedContainsError) { + t.Fatalf("expected error containing %q, got: %v", tt.expectedContainsError, err) + } + }) + } +}