Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf);
}

var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, loggerFactory);
var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, null, loggerFactory);

if (hubs?.Count > 0)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading;
Expand Down Expand Up @@ -31,5 +31,7 @@ internal interface ICallerClientResultsManager : IClientResultsManager
bool TryCompleteResult(string connectionId, ErrorCompletionMessage message);

void RemoveInvocation(string invocationId);

void SetAckNumber(string invocationId, int ackNumber);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ namespace Microsoft.Azure.SignalR;
internal class MultiEndpointMessageWriter : IServiceMessageWriter, IPresenceManager
{
private readonly ILogger _logger;
private readonly IClientInvocationManager _clientInvocationManager;

internal HubServiceEndpoint[] TargetEndpoints { get; }

public MultiEndpointMessageWriter(IReadOnlyCollection<ServiceEndpoint> targetEndpoints, ILoggerFactory loggerFactory)
public MultiEndpointMessageWriter(IReadOnlyCollection<ServiceEndpoint> targetEndpoints, IClientInvocationManager invocationManager, ILoggerFactory loggerFactory)
{
_clientInvocationManager = invocationManager;
_logger = loggerFactory.CreateLogger<MultiEndpointMessageWriter>();
var normalized = new List<HubServiceEndpoint>();
if (targetEndpoints != null)
Expand All @@ -52,6 +54,12 @@ public MultiEndpointMessageWriter(IReadOnlyCollection<ServiceEndpoint> targetEnd

public Task WriteAsync(ServiceMessage serviceMessage)
{
if (serviceMessage is ClientInvocationMessage invocationMessage)
{
// Accroding to target endpoints in method `WriteMultiEndpointMessageAsync`
_clientInvocationManager.Caller.SetAckNumber(invocationMessage.InvocationId, TargetEndpoints.Length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when TargetEndpoints.Length is 0, the result is OK?

Copy link
Contributor Author

@xingsy97 xingsy97 Feb 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class AckHandler could handle such condition correctly. Refer to its method SetExpectedCount

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we still need to SetAck if Length == 0?

}

return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta

private readonly object _lock = new object();

private readonly IClientInvocationManager _clientInvocationManager;

private (bool needRouter, IReadOnlyList<HubServiceEndpoint> endpoints) _routerEndpoints;

private int _started;
Expand All @@ -56,13 +58,15 @@ public MultiEndpointServiceConnectionContainer(
int? maxCount,
IServiceEndpointManager endpointManager,
IMessageRouter router,
IClientInvocationManager clientInvocationManager,
ILoggerFactory loggerFactory,
TimeSpan? scaleTimeout = null
) : this(
hub,
endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, maxCount, loggerFactory),
endpointManager,
router,
clientInvocationManager,
loggerFactory,
scaleTimeout)
{
Expand All @@ -73,6 +77,7 @@ internal MultiEndpointServiceConnectionContainer(
Func<HubServiceEndpoint, IServiceConnectionContainer> generator,
IServiceEndpointManager endpointManager,
IMessageRouter router,
IClientInvocationManager clientInvocationManager,
ILoggerFactory loggerFactory,
TimeSpan? scaleTimeout = null)
{
Expand All @@ -90,6 +95,7 @@ internal MultiEndpointServiceConnectionContainer(
_loggerFactory = loggerFactory;
_logger = loggerFactory?.CreateLogger<MultiEndpointServiceConnectionContainer>() ?? throw new ArgumentNullException(nameof(loggerFactory));
_serviceEndpointManager = endpointManager;
_clientInvocationManager = clientInvocationManager;
_scaleTimeout = scaleTimeout ?? Constants.Periods.DefaultScaleTimeout;

// Reserve generator for potential scale use.
Expand Down Expand Up @@ -158,7 +164,7 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints;
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory);
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _clientInvocationManager, _loggerFactory);
return messageWriter.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
}

Expand Down Expand Up @@ -271,7 +277,7 @@ private static IServiceConnectionContainer CreateContainer(IServiceConnectionFac
private MultiEndpointMessageWriter CreateMessageWriter(ServiceMessage serviceMessage)
{
var targetEndpoints = GetRoutedEndpoints(serviceMessage)?.ToList();
return new MultiEndpointMessageWriter(targetEndpoints, _loggerFactory);
return new MultiEndpointMessageWriter(targetEndpoints, _clientInvocationManager, _loggerFactory);
}

private void OnAdd(HubServiceEndpoint endpoint)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand All @@ -18,16 +18,20 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa

private readonly IServiceConnectionFactory _serviceConnectionFactory;

private readonly IClientInvocationManager _clientInvocationManager;

public ServiceConnectionContainerFactory(IServiceConnectionFactory serviceConnectionFactory,
IServiceEndpointManager serviceEndpointManager,
IMessageRouter router,
IServiceEndpointOptions options,
IClientInvocationManager clientInvocationManager,
ILoggerFactory loggerFactory)
{
_serviceConnectionFactory = serviceConnectionFactory;
_serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_options = options;
_clientInvocationManager = clientInvocationManager;
_loggerFactory = loggerFactory;
}

Expand All @@ -39,6 +43,7 @@ public IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTime
_options.MaxHubServerConnectionCount,
_serviceEndpointManager,
_router,
_clientInvocationManager,
_loggerFactory,
serviceScaleTimeout);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.Azure.SignalR.Common;
Expand All @@ -14,14 +14,16 @@ internal class MultiEndpointConnectionContainerFactory
private readonly IServiceEndpointManager _endpointManager;
private readonly int _connectionCount;
private readonly IEndpointRouter _router;
private readonly IClientInvocationManager _clientInvocationManager;

public MultiEndpointConnectionContainerFactory(IServiceConnectionFactory connectionFactory, ILoggerFactory loggerFactory, IServiceEndpointManager serviceEndpointManager, IOptions<ServiceManagerOptions> options, IEndpointRouter router = null)
public MultiEndpointConnectionContainerFactory(IServiceConnectionFactory connectionFactory, ILoggerFactory loggerFactory, IServiceEndpointManager serviceEndpointManager, IOptions<ServiceManagerOptions> options, IEndpointRouter router = null, IClientInvocationManager clientInvocationManager = null)
{
_connectionFactory = connectionFactory;
_loggerFactory = loggerFactory;
_endpointManager = serviceEndpointManager;
_connectionCount = options.Value.ConnectionCount;
_router = router;
_clientInvocationManager = clientInvocationManager;
}

public MultiEndpointServiceConnectionContainer Create(string hubName)
Expand All @@ -31,8 +33,9 @@ public MultiEndpointServiceConnectionContainer Create(string hubName)
endpoint => new WeakServiceConnectionContainer(_connectionFactory, _connectionCount, endpoint, _loggerFactory.CreateLogger<WeakServiceConnectionContainer>()),
_endpointManager,
_router,
_clientInvocationManager,
_loggerFactory);
return container;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down Expand Up @@ -96,7 +96,7 @@ public override ServiceHubContext WithEndpoints(IEnumerable<ServiceEndpoint> end
private sealed class MessageWriterServiceContainerWrapper : MultiEndpointMessageWriter, IServiceConnectionContainer
{
public MessageWriterServiceContainerWrapper(IReadOnlyCollection<ServiceEndpoint> targetEndpoints, ILoggerFactory loggerFactory)
: base(targetEndpoints, loggerFactory) { }
: base(targetEndpoints, null, loggerFactory) { }

public Task StartAsync() => Task.CompletedTask;

Expand Down Expand Up @@ -125,4 +125,4 @@ public void Dispose()
#endregion Not supported method or properties
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#if NET7_0_OR_GREATER
using System;
Expand Down Expand Up @@ -206,6 +206,14 @@ public void RemoveInvocation(string invocationId)
_pendingInvocations.TryRemove(invocationId, out _);
}

public void SetAckNumber(string invocationId, int ackNumber)
{
if (_pendingInvocations.TryGetValue(invocationId, out var item))
{
_ackHandler.SetExpectedCount(item.AckId, ackNumber);
}
}

// Unused, here to honor the IInvocationBinder interface but should never be called
public IReadOnlyList<Type> GetParameterTypes(string methodName) => throw new NotImplementedException();

Expand All @@ -218,4 +226,4 @@ private record PendingInvocation(Type Type, string ConnectionId, object Tcs, int
}
}
}
#endif
#endif
3 changes: 2 additions & 1 deletion src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down Expand Up @@ -195,6 +195,7 @@ private IServiceConnectionContainer GetServiceConnectionContainer(ConnectionDele
_serviceEndpointManager,
_router,
_options,
_clientInvocationManager,
_loggerFactory
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub,
Func<HubServiceEndpoint, IServiceConnectionContainer> generator,
IServiceEndpointManager endpoint,
IEndpointRouter router,
ILoggerFactory loggerFactory) : base(hub, generator, endpoint, router, loggerFactory)
ILoggerFactory loggerFactory) : base(hub, generator, endpoint, router, null, loggerFactory)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public async Task ListConnectionsInGroup(int? top, int resultCount, params int?[
endpoint.ConnectionContainer = containerMock.Object;
targetEndpoints.Add(endpoint);
}
var multiEndpointWriter = new MultiEndpointMessageWriter(targetEndpoints, Mock.Of<ILoggerFactory>());
var multiEndpointWriter = new MultiEndpointMessageWriter(targetEndpoints, null, Mock.Of<ILoggerFactory>());
var resultMembers = new List<GroupMember>();
await foreach (var member in multiEndpointWriter.ListConnectionsInGroupAsync("group", top))
{
Expand Down
22 changes: 16 additions & 6 deletions test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,18 @@ public class ClientInvocationManagerTests
private const string SuccessCompleteResult = "success-result";
private const string ErrorCompleteResult = "error-result";

private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1)
private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1, int badEndpointsCount = 0)
{
var services = new ServiceCollection();
var endpoints = Enumerable.Range(0, endpointCount)
.Select(i => new ServiceEndpoint($"Endpoint=https://test{i}connectionstring;AccessKey=1"))
.ToArray();

for (var i = 0; i < badEndpointsCount && i < endpointCount; i++)
{
endpoints[i].Online = false;
}

var config = new ConfigurationBuilder().Build();

var serviceProvider = services.AddLogging()
Expand Down Expand Up @@ -175,14 +180,19 @@ public void TestCallerManagerCancellation()
}

[Theory]
[InlineData(true, 2)]
[InlineData(false, 2)]
[InlineData(true, 3)]
[InlineData(false, 3)]
[InlineData(true, 2, 0)]
[InlineData(true, 2, 1)]
[InlineData(true, 2, 2)]
[InlineData(false, 2, 0)]
[InlineData(false, 2, 1)]
[InlineData(false, 2, 2)]
[InlineData(true, 3, 0)]
[InlineData(false, 3, 0)]
// isCompletionWithResult: the invocation is completed with result or error
public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount)
public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount, int badEndpointsCount)
{
Assert.True(endpointsCount > 1);
Assert.True(endpointsCount >= badEndpointsCount);
var clientInvocationManager = GetTestClientInvocationManager(endpointsCount);
var connectionId = TestConnectionIds[0];
var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ public async Task TestStatusPingChangesEndpointStatus()
var connectionFactory1 = new TestServiceConnectionFactory();
var connectionFactory2 = new TestServiceConnectionFactory();

var hub1 = new MultiEndpointServiceConnectionContainer(connectionFactory1, "hub1", 2, null, sem, router,
var hub1 = new MultiEndpointServiceConnectionContainer(connectionFactory1, "hub1", 2, null, sem, router, null,
loggerFactory);
var hub2 = new MultiEndpointServiceConnectionContainer(connectionFactory2, "hub2", 2, null, sem, router,
var hub2 = new MultiEndpointServiceConnectionContainer(connectionFactory2, "hub2", 2, null, sem, router, null,
loggerFactory);

var connections = connectionFactory1.CreatedConnections.SelectMany(kv => kv.Value).ToArray();
Expand Down Expand Up @@ -1985,7 +1985,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub,
IEndpointRouter router,
ILoggerFactory loggerFactory,
TimeSpan? _ = null
) : base(hub, generator, endpoint, router, loggerFactory)
) : base(hub, generator, endpoint, router, null, loggerFactory)
{
}

Expand Down
Loading