Skip to content

Commit d21d933

Browse files
authored
Move notification handler registrations to capabilities (#207)
* Move notification handler registrations to capabilities Currently request handlers are set on the capability objects, but notification handlers are set after construction via an AddNotificationHandler method on the IMcpEndpoint interface. This moves handler specification to be at construction as well. This makes it more consistent with request handlers, simplifies the IMcpEndpoint interface to just be about message sending, and avoids a concurrency bug that could occur if someone tried to add a handler while the endpoint was processing notifications. * Address more feedback and further cleanup
1 parent 8fcdf95 commit d21d933

35 files changed

+292
-288
lines changed

samples/QuickstartWeatherServer/Tools/WeatherTools.cs

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using ModelContextProtocol;
22
using ModelContextProtocol.Server;
33
using System.ComponentModel;
4-
using System.Net.Http.Json;
54
using System.Text.Json;
65

76
namespace QuickstartWeatherServer.Tools;

src/ModelContextProtocol/Client/McpClient.cs

+53-37
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ namespace ModelContextProtocol.Client;
1212
/// <inheritdoc/>
1313
internal sealed class McpClient : McpEndpoint, IMcpClient
1414
{
15+
private static Implementation DefaultImplementation { get; } = new()
16+
{
17+
Name = DefaultAssemblyName.Name ?? nameof(McpClient),
18+
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
19+
};
20+
1521
private readonly IClientTransport _clientTransport;
1622
private readonly McpClientOptions _options;
1723

@@ -29,43 +35,53 @@ internal sealed class McpClient : McpEndpoint, IMcpClient
2935
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
3036
/// <param name="serverConfig">The server configuration.</param>
3137
/// <param name="loggerFactory">The logger factory.</param>
32-
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
38+
public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
3339
: base(loggerFactory)
3440
{
41+
options ??= new();
42+
3543
_clientTransport = clientTransport;
3644
_options = options;
3745

3846
EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
3947

40-
if (options.Capabilities?.Sampling is { } samplingCapability)
48+
if (options.Capabilities is { } capabilities)
4149
{
42-
if (samplingCapability.SamplingHandler is not { } samplingHandler)
50+
if (capabilities.NotificationHandlers is { } notificationHandlers)
4351
{
44-
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
52+
NotificationHandlers.AddRange(notificationHandlers);
4553
}
4654

47-
SetRequestHandler(
48-
RequestMethods.SamplingCreateMessage,
49-
(request, cancellationToken) => samplingHandler(
50-
request,
51-
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
52-
cancellationToken),
53-
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
54-
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
55-
}
56-
57-
if (options.Capabilities?.Roots is { } rootsCapability)
58-
{
59-
if (rootsCapability.RootsHandler is not { } rootsHandler)
55+
if (capabilities.Sampling is { } samplingCapability)
6056
{
61-
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
57+
if (samplingCapability.SamplingHandler is not { } samplingHandler)
58+
{
59+
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
60+
}
61+
62+
RequestHandlers.Set(
63+
RequestMethods.SamplingCreateMessage,
64+
(request, cancellationToken) => samplingHandler(
65+
request,
66+
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
67+
cancellationToken),
68+
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
69+
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
6270
}
6371

64-
SetRequestHandler(
65-
RequestMethods.RootsList,
66-
rootsHandler,
67-
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
68-
McpJsonUtilities.JsonContext.Default.ListRootsResult);
72+
if (capabilities.Roots is { } rootsCapability)
73+
{
74+
if (rootsCapability.RootsHandler is not { } rootsHandler)
75+
{
76+
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
77+
}
78+
79+
RequestHandlers.Set(
80+
RequestMethods.RootsList,
81+
rootsHandler,
82+
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
83+
McpJsonUtilities.JsonContext.Default.ListRootsResult);
84+
}
6985
}
7086
}
7187

@@ -96,20 +112,20 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
96112
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
97113
initializationCts.CancelAfter(_options.InitializationTimeout);
98114

99-
try
100-
{
101-
// Send initialize request
102-
var initializeResponse = await this.SendRequestAsync(
103-
RequestMethods.Initialize,
104-
new InitializeRequestParams
105-
{
106-
ProtocolVersion = _options.ProtocolVersion,
107-
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
108-
ClientInfo = _options.ClientInfo
109-
},
110-
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
111-
McpJsonUtilities.JsonContext.Default.InitializeResult,
112-
cancellationToken: initializationCts.Token).ConfigureAwait(false);
115+
try
116+
{
117+
// Send initialize request
118+
var initializeResponse = await this.SendRequestAsync(
119+
RequestMethods.Initialize,
120+
new InitializeRequestParams
121+
{
122+
ProtocolVersion = _options.ProtocolVersion,
123+
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
124+
ClientInfo = _options.ClientInfo ?? DefaultImplementation,
125+
},
126+
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
127+
McpJsonUtilities.JsonContext.Default.InitializeResult,
128+
cancellationToken: initializationCts.Token).ConfigureAwait(false);
113129

114130
// Store server information
115131
_logger.ServerCapabilitiesReceived(EndpointName,

src/ModelContextProtocol/Client/McpClientFactory.cs

-19
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,12 @@
55
using ModelContextProtocol.Utils;
66
using Microsoft.Extensions.Logging;
77
using Microsoft.Extensions.Logging.Abstractions;
8-
using System.Reflection;
98

109
namespace ModelContextProtocol.Client;
1110

1211
/// <summary>Provides factory methods for creating MCP clients.</summary>
1312
public static class McpClientFactory
1413
{
15-
/// <summary>Default client options to use when none are supplied.</summary>
16-
private static readonly McpClientOptions s_defaultClientOptions = CreateDefaultClientOptions();
17-
18-
/// <summary>Creates default client options to use when no options are supplied.</summary>
19-
private static McpClientOptions CreateDefaultClientOptions()
20-
{
21-
var asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
22-
return new()
23-
{
24-
ClientInfo = new()
25-
{
26-
Name = asmName.Name ?? "McpClient",
27-
Version = asmName.Version?.ToString() ?? "1.0.0",
28-
},
29-
};
30-
}
31-
3214
/// <summary>Creates an <see cref="IMcpClient"/>, connecting it to the specified server.</summary>
3315
/// <param name="serverConfig">Configuration for the target server to which the client should connect.</param>
3416
/// <param name="clientOptions">
@@ -52,7 +34,6 @@ public static async Task<IMcpClient> CreateAsync(
5234
{
5335
Throw.IfNull(serverConfig);
5436

55-
clientOptions ??= s_defaultClientOptions;
5637
createTransportFunc ??= CreateTransport;
5738

5839
string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";

src/ModelContextProtocol/Client/McpClientOptions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class McpClientOptions
1212
/// <summary>
1313
/// Information about this client implementation.
1414
/// </summary>
15-
public required Implementation ClientInfo { get; set; }
15+
public Implementation? ClientInfo { get; set; }
1616

1717
/// <summary>
1818
/// Client capabilities to advertise to the server.

src/ModelContextProtocol/Client/McpClientTool.cs

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using ModelContextProtocol.Protocol.Types;
22
using ModelContextProtocol.Utils.Json;
3-
using ModelContextProtocol.Utils;
43
using Microsoft.Extensions.AI;
54
using System.Text.Json;
65

src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs

+1-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using System.Reflection;
2-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Server;
32
using Microsoft.Extensions.Options;
43
using ModelContextProtocol.Utils;
54

@@ -25,18 +24,6 @@ public void Configure(McpServerOptions options)
2524
{
2625
Throw.IfNull(options);
2726

28-
// Configure the option's server information based on the current process,
29-
// if it otherwise lacks server information.
30-
if (options.ServerInfo is not { } serverInfo)
31-
{
32-
var assemblyName = Assembly.GetEntryAssembly()?.GetName();
33-
options.ServerInfo = new()
34-
{
35-
Name = assemblyName?.Name ?? "McpServer",
36-
Version = assemblyName?.Version?.ToString() ?? "1.0.0",
37-
};
38-
}
39-
4027
// Collect all of the provided tools into a tools collection. If the options already has
4128
// a collection, add to it, otherwise create a new one. We want to maintain the identity
4229
// of an existing collection in case someone has provided their own derived type, wants

src/ModelContextProtocol/IMcpEndpoint.cs

-16
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,4 @@ public interface IMcpEndpoint : IAsyncDisposable
1515
/// <param name="message">The message.</param>
1616
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
1717
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
18-
19-
/// <summary>
20-
/// Adds a handler for server notifications of a specific method.
21-
/// </summary>
22-
/// <param name="method">The notification method to handle.</param>
23-
/// <param name="handler">The async handler function to process notifications.</param>
24-
/// <remarks>
25-
/// <para>
26-
/// Each method may have multiple handlers. Adding a handler for a method that already has one
27-
/// will not replace the existing handler.
28-
/// </para>
29-
/// <para>
30-
/// <see cref="NotificationMethods"> provides constants for common notification methods.</see>
31-
/// </para>
32-
/// </remarks>
33-
void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler);
3418
}

src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public McpTransportException(string message)
3030
/// </summary>
3131
/// <param name="message">The message that describes the error.</param>
3232
/// <param name="innerException">The exception that is the cause of the current exception.</param>
33-
public McpTransportException(string message, Exception innerException)
33+
public McpTransportException(string message, Exception? innerException)
3434
: base(message, innerException)
3535
{
3636
}

src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs

+15-3
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,22 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process
2121
/// <inheritdoc/>
2222
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
2323
{
24-
if (_process.HasExited)
24+
Exception? processException = null;
25+
bool hasExited = false;
26+
try
27+
{
28+
hasExited = _process.HasExited;
29+
}
30+
catch (Exception e)
31+
{
32+
processException = e;
33+
hasExited = true;
34+
}
35+
36+
if (hasExited)
2537
{
2638
Logger.TransportNotConnected(EndpointName);
27-
throw new McpTransportException("Transport is not connected");
39+
throw new McpTransportException("Transport is not connected", processException);
2840
}
2941

3042
await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
@@ -33,7 +45,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
3345
/// <inheritdoc/>
3446
protected override ValueTask CleanupAsync(CancellationToken cancellationToken)
3547
{
36-
StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName);
48+
StdioClientTransport.DisposeProcess(_process, processRunning: true, Logger, _options.ShutdownTimeout, EndpointName);
3749

3850
return base.CleanupAsync(cancellationToken);
3951
}

src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs

+14-2
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,25 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
129129
}
130130

131131
internal static void DisposeProcess(
132-
Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
132+
Process? process, bool processRunning, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
133133
{
134134
if (process is not null)
135135
{
136+
if (processRunning)
137+
{
138+
try
139+
{
140+
processRunning = !process.HasExited;
141+
}
142+
catch
143+
{
144+
processRunning = false;
145+
}
146+
}
147+
136148
try
137149
{
138-
if (processStarted && !process.HasExited)
150+
if (processRunning)
139151
{
140152
// Wait for the process to exit.
141153
// Kill the while process tree because the process may spawn child processes

src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n
7070
private static string GetServerName(McpServerOptions serverOptions)
7171
{
7272
Throw.IfNull(serverOptions);
73-
Throw.IfNull(serverOptions.ServerInfo);
74-
Throw.IfNull(serverOptions.ServerInfo.Name);
7573

76-
return serverOptions.ServerInfo.Name;
74+
return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name;
7775
}
7876
}

src/ModelContextProtocol/Protocol/Types/Capabilities.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Protocol.Messages;
2+
using ModelContextProtocol.Server;
23
using System.Text.Json.Serialization;
34

45
namespace ModelContextProtocol.Protocol.Types;
@@ -26,6 +27,14 @@ public class ClientCapabilities
2627
/// </summary>
2728
[JsonPropertyName("sampling")]
2829
public SamplingCapability? Sampling { get; set; }
30+
31+
/// <summary>Gets or sets notification handlers to register with the client.</summary>
32+
/// <remarks>
33+
/// When constructed, the client will enumerate these handlers, which may contain multiple handlers per key.
34+
/// The client will not re-enumerate the sequence.
35+
/// </remarks>
36+
[JsonIgnore]
37+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
2938
}
3039

3140
/// <summary>

src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
3-
namespace ModelContextProtocol.Protocol.Types;
1+
namespace ModelContextProtocol.Protocol.Types;
42

53
/// <summary>
64
/// A request from the server to get a list of root URIs from the client.

src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text.Json.Serialization;
1+
using ModelContextProtocol.Protocol.Messages;
2+
using System.Text.Json.Serialization;
23

34
namespace ModelContextProtocol.Protocol.Types;
45

@@ -37,4 +38,12 @@ public class ServerCapabilities
3738
/// </summary>
3839
[JsonPropertyName("tools")]
3940
public ToolsCapability? Tools { get; set; }
41+
42+
/// <summary>Gets or sets notification handlers to register with the server.</summary>
43+
/// <remarks>
44+
/// When constructed, the server will enumerate these handlers, which may contain multiple handlers per key.
45+
/// The server will not re-enumerate the sequence.
46+
/// </remarks>
47+
[JsonIgnore]
48+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
4049
}

0 commit comments

Comments
 (0)