diff --git a/contrib/mark3labs/mcp-go/README.md b/contrib/mark3labs/mcp-go/README.md index 32ff4e3c68..7273d06721 100644 --- a/contrib/mark3labs/mcp-go/README.md +++ b/contrib/mark3labs/mcp-go/README.md @@ -17,13 +17,11 @@ func main() { tracer.Start() defer tracer.Stop() - // Add tracing to your server hooks - hooks := &server.Hooks{} - mcpgotrace.AddServerHooks(hooks) - + // Do not use with `server.WithHooks(...)`, as this overwrites the tracing hooks. + // To add custom hooks alongside tracing, pass them via TracingConfig.Hooks, e.g.: + // mcpgotrace.WithMCPServerTracing(&mcpgotrace.TracingConfig{Hooks: customHooks}) srv := server.NewMCPServer("my-server", "1.0.0", - server.WithHooks(hooks), - server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware())) + mcpgotrace.WithMCPServerTracing(nil)) } ``` diff --git a/contrib/mark3labs/mcp-go/example_test.go b/contrib/mark3labs/mcp-go/example_test.go index b4c32d8a53..e36603e132 100644 --- a/contrib/mark3labs/mcp-go/example_test.go +++ b/contrib/mark3labs/mcp-go/example_test.go @@ -6,8 +6,11 @@ package mcpgo_test import ( + "context" + mcpgotrace "github.com/DataDog/dd-trace-go/contrib/mark3labs/mcp-go/v2" "github.com/DataDog/dd-trace-go/v2/ddtrace/tracer" + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) @@ -15,12 +18,21 @@ func Example() { tracer.Start() defer tracer.Stop() - // Create server hooks and add Datadog tracing - hooks := &server.Hooks{} - mcpgotrace.AddServerHooks(hooks) + srv := server.NewMCPServer("my-server", "1.0.0", + mcpgotrace.WithMCPServerTracing(nil)) + _ = srv +} + +func Example_withCustomHooks() { + tracer.Start() + defer tracer.Stop() + + customHooks := &server.Hooks{} + customHooks.AddBeforeInitialize(func(ctx context.Context, id any, request *mcp.InitializeRequest) { + // Your custom logic here + }) srv := server.NewMCPServer("my-server", "1.0.0", - server.WithHooks(hooks), - server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware())) + mcpgotrace.WithMCPServerTracing(&mcpgotrace.TracingConfig{Hooks: customHooks})) _ = srv } diff --git a/contrib/mark3labs/mcp-go/option.go b/contrib/mark3labs/mcp-go/option.go new file mode 100644 index 0000000000..de2560e40a --- /dev/null +++ b/contrib/mark3labs/mcp-go/option.go @@ -0,0 +1,59 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2025 Datadog, Inc. + +package mcpgo + +import ( + "github.com/mark3labs/mcp-go/server" +) + +// The file contains methods for easily adding tracing to a MCP server. + +// TracingConfig holds configuration for adding tracing to an MCP server. +type TracingConfig struct { + // Hooks allows you to provide custom hooks that will be merged with Datadog tracing hooks. + // If nil, only Datadog tracing hooks will be added and any custom hooks provided via server.WithHooks(...) will be removed. + // If provided, your custom hooks will be executed alongside Datadog tracing hooks. + Hooks *server.Hooks +} + +// WithMCPServerTracing adds Datadog tracing to an MCP server. +// Pass this option to server.NewMCPServer to enable tracing. +// +// Do not use with `server.WithHooks(...)`, as this overwrites the hooks. +// Instead, pass custom hooks in the TracingConfig, which will be merged with tracing hooks. +// +// Usage: +// +// // Simple usage with only tracing hooks +// srv := server.NewMCPServer("my-server", "1.0.0", +// WithMCPServerTracing(nil)) +// +// // With custom hooks +// customHooks := &server.Hooks{} +// customHooks.AddBeforeInitialize(func(ctx context.Context, id any, request *mcp.InitializeRequest) { +// // Your custom logic here +// }) +// srv := server.NewMCPServer("my-server", "1.0.0", +// WithMCPServerTracing(&TracingConfig{Hooks: customHooks})) +func WithMCPServerTracing(options *TracingConfig) server.ServerOption { + return func(s *server.MCPServer) { + if options == nil { + options = new(TracingConfig) + } + + hooks := options.Hooks + + // Append hooks (hooks is a private field) + if hooks == nil { + hooks = &server.Hooks{} + } + appendTracingHooks(hooks) + + server.WithHooks(hooks)(s) + + server.WithToolHandlerMiddleware(toolHandlerMiddleware)(s) + } +} diff --git a/contrib/mark3labs/mcp-go/mcpgo.go b/contrib/mark3labs/mcp-go/tracing.go similarity index 69% rename from contrib/mark3labs/mcp-go/mcpgo.go rename to contrib/mark3labs/mcp-go/tracing.go index fc75adf0e7..f7ff3c9f15 100644 --- a/contrib/mark3labs/mcp-go/mcpgo.go +++ b/contrib/mark3labs/mcp-go/tracing.go @@ -27,46 +27,44 @@ type hooks struct { spanCache *sync.Map } -// AddServerHooks appends Datadog tracing hooks to an existing server.Hooks object. -func AddServerHooks(hooks *server.Hooks) { - ddHooks := newHooks() - hooks.AddBeforeInitialize(ddHooks.onBeforeInitialize) - hooks.AddAfterInitialize(ddHooks.onAfterInitialize) - hooks.AddOnError(ddHooks.onError) +// appendTracingHooks appends Datadog tracing hooks to an existing server.Hooks object. +func appendTracingHooks(hooks *server.Hooks) { + tracingHooks := newHooks() + hooks.AddBeforeInitialize(tracingHooks.onBeforeInitialize) + hooks.AddAfterInitialize(tracingHooks.onAfterInitialize) + hooks.AddOnError(tracingHooks.onError) } -func NewToolHandlerMiddleware() server.ToolHandlerMiddleware { - return func(next server.ToolHandlerFunc) server.ToolHandlerFunc { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - toolSpan, ctx := llmobs.StartToolSpan(ctx, request.Params.Name, llmobs.WithIntegration(string(instrumentation.PackageMark3LabsMCPGo))) +var toolHandlerMiddleware = func(next server.ToolHandlerFunc) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + toolSpan, ctx := llmobs.StartToolSpan(ctx, request.Params.Name, llmobs.WithIntegration(string(instrumentation.PackageMark3LabsMCPGo))) - result, err := next(ctx, request) + result, err := next(ctx, request) - inputJSON, marshalErr := json.Marshal(request) + inputJSON, marshalErr := json.Marshal(request) + if marshalErr != nil { + instr.Logger().Warn("mcp-go: failed to marshal tool request: %v", marshalErr) + } + var outputText string + if result != nil { + resultJSON, marshalErr := json.Marshal(result) if marshalErr != nil { - instr.Logger().Warn("mcp-go: failed to marshal tool request: %v", marshalErr) - } - var outputText string - if result != nil { - resultJSON, marshalErr := json.Marshal(result) - if marshalErr != nil { - instr.Logger().Warn("mcp-go: failed to marshal tool result: %v", marshalErr) - } - outputText = string(resultJSON) + instr.Logger().Warn("mcp-go: failed to marshal tool result: %v", marshalErr) } + outputText = string(resultJSON) + } - tagWithSessionID(ctx, toolSpan) - - toolSpan.AnnotateTextIO(string(inputJSON), outputText) + tagWithSessionID(ctx, toolSpan) - if err != nil { - toolSpan.Finish(llmobs.WithError(err)) - } else { - toolSpan.Finish() - } + toolSpan.AnnotateTextIO(string(inputJSON), outputText) - return result, err + if err != nil { + toolSpan.Finish(llmobs.WithError(err)) + } else { + toolSpan.Finish() } + + return result, err } } diff --git a/contrib/mark3labs/mcp-go/mcpgo_test.go b/contrib/mark3labs/mcp-go/tracing_test.go similarity index 87% rename from contrib/mark3labs/mcp-go/mcpgo_test.go rename to contrib/mark3labs/mcp-go/tracing_test.go index 678feb6974..287bfaf726 100644 --- a/contrib/mark3labs/mcp-go/mcpgo_test.go +++ b/contrib/mark3labs/mcp-go/tracing_test.go @@ -21,11 +21,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewToolHandlerMiddleware(t *testing.T) { +func TestToolHandlerMiddleware(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - middleware := NewToolHandlerMiddleware() + middleware := toolHandlerMiddleware assert.NotNil(t, middleware) } @@ -34,7 +34,7 @@ func TestAddServerHooks(t *testing.T) { defer mt.Stop() serverHooks := &server.Hooks{} - AddServerHooks(serverHooks) + appendTracingHooks(serverHooks) assert.Len(t, serverHooks.OnBeforeInitialize, 1) assert.Len(t, serverHooks.OnAfterInitialize, 1) @@ -45,11 +45,8 @@ func TestIntegrationSessionInitialize(t *testing.T) { tt := testTracer(t) defer tt.Stop() - hooks := &server.Hooks{} - AddServerHooks(hooks) - srv := server.NewMCPServer("test-server", "1.0.0", - server.WithHooks(hooks)) + WithMCPServerTracing(nil)) ctx := context.Background() sessionID := "test-session-init" @@ -109,11 +106,10 @@ func TestIntegrationToolCallSuccess(t *testing.T) { defer tt.Stop() hooks := &server.Hooks{} - AddServerHooks(hooks) + appendTracingHooks(hooks) srv := server.NewMCPServer("test-server", "1.0.0", - server.WithHooks(hooks), - server.WithToolHandlerMiddleware(NewToolHandlerMiddleware())) + WithMCPServerTracing(nil)) calcTool := mcp.NewTool("calculator", mcp.WithDescription("A simple calculator")) @@ -211,7 +207,7 @@ func TestIntegrationToolCallError(t *testing.T) { defer tt.Stop() srv := server.NewMCPServer("test-server", "1.0.0", - server.WithToolHandlerMiddleware(NewToolHandlerMiddleware())) + WithMCPServerTracing(&TracingConfig{})) errorTool := mcp.NewTool("error_tool", mcp.WithDescription("A tool that always errors")) @@ -258,6 +254,35 @@ func TestIntegrationToolCallError(t *testing.T) { assert.Contains(t, toolSpan.Meta, "input") } +func TestWithMCPServerTracingWithCustomHooks(t *testing.T) { + tt := testTracer(t) + defer tt.Stop() + + customHookCalled := false + customHooks := &server.Hooks{} + customHooks.AddBeforeInitialize(func(ctx context.Context, id any, request *mcp.InitializeRequest) { + customHookCalled = true + }) + + srv := server.NewMCPServer("test-server", "1.0.0", + WithMCPServerTracing(&TracingConfig{Hooks: customHooks})) + + ctx := context.Background() + initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}` + + response := srv.HandleMessage(ctx, []byte(initRequest)) + assert.NotNil(t, response) + + assert.True(t, customHookCalled, "custom hook should have been called") + + spans := tt.WaitForLLMObsSpans(t, 1) + require.Len(t, spans, 1) + + taskSpan := spans[0] + assert.Equal(t, "mcp.initialize", taskSpan.Name) + assert.Equal(t, "task", taskSpan.Meta["span.kind"]) +} + // Test helpers // testTracer creates a testtracer with LLMObs enabled for integration tests