Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ func writeEvent(w io.Writer, evt Event) (int, error) {
//
// TODO(rfindley): consider a different API here that makes failure modes more
// apparent.
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
func scanEvents(r io.Reader, maxLineSize int) iter.Seq2[Event, error] {
scanner := bufio.NewScanner(r)
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
scanner.Buffer(nil, maxTokenSize)
if maxLineSize == 0 {
maxLineSize = 1 * 1024 * 1024 // defaults to 1MiB
}
scanner.Buffer(nil, maxLineSize)

// TODO: investigate proper behavior when events are out of order, or have
// non-standard names.
Expand Down Expand Up @@ -139,7 +141,7 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
err = fmt.Errorf("event exceeded max line length of %d", maxLineSize)
}
if !yield(Event{}, err) {
return
Expand Down
17 changes: 12 additions & 5 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import (

func TestScanEvents(t *testing.T) {
tests := []struct {
name string
input string
want []Event
wantErr string
name string
input string
want []Event
wantErr string
maxLineSize int
}{
{
name: "simple event",
Expand Down Expand Up @@ -54,14 +55,20 @@ func TestScanEvents(t *testing.T) {
input: "invalid line\n\n",
wantErr: "malformed line",
},
{
name: "event exceeds buffer size",
input: "data: " + strings.Repeat("x", 200) + "\n\n",
maxLineSize: 100,
wantErr: "event exceeded max line length of 100",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := strings.NewReader(tt.input)
var got []Event
var err error
for e, err2 := range scanEvents(r) {
for e, err2 := range scanEvents(r, tt.maxLineSize) {
if err2 != nil {
err = err2
break
Expand Down
34 changes: 25 additions & 9 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ type SSEClientTransport struct {
// HTTPClient is the client to use for making HTTP requests. If nil,
// http.DefaultClient is used.
HTTPClient *http.Client

// MaxLineSize is the maximum buffer size (in bytes) used when reading SSE events. If unset, defaults to 1MiB.
MaxLineSize int
}

// Connect connects through the client endpoint.
Expand All @@ -353,7 +356,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {

msgEndpoint, err := func() (*url.URL, error) {
var evt Event
for evt, err = range scanEvents(resp.Body) {
for evt, err = range scanEvents(resp.Body, c.MaxLineSize) {
break
}
if err != nil {
Expand All @@ -374,20 +377,24 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
s := &sseClientConn{
client: httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
incoming: make(chan sseMessage, 100),
body: resp.Body,
done: make(chan struct{}),
}

go func() {
defer s.Close() // close the transport when the GET exits

for evt, err := range scanEvents(resp.Body) {
for evt, err := range scanEvents(resp.Body, c.MaxLineSize) {
if err != nil {
select {
case s.incoming <- sseMessage{err: err}:
case <-s.done:
}
return
}
select {
case s.incoming <- evt.Data:
case s.incoming <- sseMessage{data: evt.Data}:
case <-s.done:
return
}
Expand All @@ -397,15 +404,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
return s, nil
}

// sseMessage represents a message or error from the SSE stream.
type sseMessage struct {
data []byte
err error
}

// An sseClientConn is a logical jsonrpc2 connection that implements the client
// half of the SSE protocol:
// - Writes are POSTS to the session endpoint.
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientConn struct {
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan sseMessage // queue of incoming messages or errors

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand All @@ -430,12 +443,15 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
case <-c.done:
return nil, io.EOF

case data := <-c.incoming:
case m := <-c.incoming:
if m.err != nil {
return nil, m.err
}
// TODO(rfindley): do we really need to check this? We receive from c.done above.
if c.isDone() {
return nil, io.EOF
}
msg, err := jsonrpc2.DecodeMessage(data)
msg, err := jsonrpc2.DecodeMessage(m.data)
if err != nil {
return nil, err
}
Expand Down
29 changes: 18 additions & 11 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,8 @@ type StreamableClientTransport struct {
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
// It defaults to 5. To disable retries, use a negative number.
MaxRetries int
// MaxLineSize is the maximum buffer size (in bytes) used when reading SSE events. If unset, defaults to 1MiB.
MaxLineSize int

// TODO(rfindley): propose exporting these.
// If strict is set, the transport is in 'strict mode', where any violation
Expand Down Expand Up @@ -1453,16 +1455,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// middleware), yet only cancel the standalone stream when the connection is closed.
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: ensureLogger(t.logger), // must be non-nil for safe logging
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: ensureLogger(t.logger), // must be non-nil for safe logging
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
maxLineSize: t.MaxLineSize,
}
return conn, nil
}
Expand Down Expand Up @@ -1497,6 +1500,7 @@ type streamableClientConn struct {
mu sync.Mutex
initializedResult *InitializeResult
sessionID string
maxLineSize int
}

// errSessionMissing distinguishes if the session is known to not be present on
Expand Down Expand Up @@ -1854,11 +1858,14 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()
for evt, err := range scanEvents(resp.Body) {
for evt, err := range scanEvents(resp.Body, c.maxLineSize) {
if err != nil {
if ctx.Err() != nil {
return "", 0, true // don't reconnect: client cancelled
}

// EOF errors are returned as nil from bufio.Scanner, so all errors should be returned back
c.fail(fmt.Errorf("%s: failed to process stream: %v", requestSummary, err))
break
}

Expand Down
4 changes: 2 additions & 2 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
var respBody []byte
if strings.HasPrefix(contentType, "text/event-stream") {
r := readerInto{resp.Body, new(bytes.Buffer)}
for evt, err := range scanEvents(r) {
for evt, err := range scanEvents(r, 0) {
if err != nil {
return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err)
}
Expand Down Expand Up @@ -2143,7 +2143,7 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}}
var events []Event

// Scan all events
for evt, err := range scanEvents(reader) {
for evt, err := range scanEvents(reader, 0) {
if err != nil {
if err != io.EOF {
t.Fatalf("scanEvents error: %v", err)
Expand Down
108 changes: 108 additions & 0 deletions mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package mcp
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

Expand Down Expand Up @@ -124,3 +126,109 @@ func TestIOConnRead(t *testing.T) {
})
}
}

func TestScanEventsBufferError(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
clientTransport func(url string) Transport
serverHandler func(server *Server) http.Handler
responseLength int
expectedContainsError string
}{
{
name: "sse-large-output",
clientTransport: func(url string) Transport {
return &SSEClientTransport{
Endpoint: url,
MaxLineSize: 1024,
}
},
serverHandler: func(server *Server) http.Handler {
return NewSSEHandler(func(req *http.Request) *Server { return server }, nil)
},
responseLength: 10000,
expectedContainsError: "exceeded max line length",
},
{
name: "streamable-large-output",
clientTransport: func(url string) Transport {
return &StreamableClientTransport{
Endpoint: url,
MaxLineSize: 1024,
}
},
serverHandler: func(server *Server) http.Handler {
return NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
},
responseLength: 10000,
expectedContainsError: "exceeded max line length",
},
{
name: "sse-small-output",
clientTransport: func(url string) Transport {
return &SSEClientTransport{
Endpoint: url,
MaxLineSize: 1024,
}
},
serverHandler: func(server *Server) http.Handler {
return NewSSEHandler(func(req *http.Request) *Server { return server }, nil)
},
responseLength: 512,
},
{
name: "streamable-small-output",
clientTransport: func(url string) Transport {
return &StreamableClientTransport{
Endpoint: url,
MaxLineSize: 1024,
}
},
serverHandler: func(server *Server) http.Handler {
return NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
},
responseLength: 512,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
largeResponse := strings.Repeat("x", tt.responseLength)
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "largeTool", Description: "returns large response"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
return &CallToolResult{Content: []Content{&TextContent{Text: largeResponse}}}, nil, nil
})

httpHandler := tt.serverHandler(server)
httpServer := httptest.NewServer(mustNotPanic(t, httpHandler))
defer httpServer.Close()

client := NewClient(testImpl, nil)
clientTransport := tt.clientTransport(httpServer.URL)
session, err := client.Connect(ctx, clientTransport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()

_, err = session.CallTool(ctx, &CallToolParams{
Name: "largeTool",
Arguments: map[string]any{},
})
if tt.expectedContainsError != "" {
if tt.expectedContainsError != "" && err == nil {
t.Fatal("expected error due to small buffer, got nil")
}

if !strings.Contains(err.Error(), "exceeded max line length") {
t.Fatalf("expected buffer-related error, got: %v", err)
}
} else {
if err != nil {
t.Fatalf("client.CallTool() unexpectedly failed: %v", err)
}
}

})
}
}