Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down Expand Up @@ -79,6 +79,17 @@ public Task SendMessageWithRetryAsync(
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponseAsync, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
Expand Down Expand Up @@ -184,7 +195,6 @@ private async Task SendAsyncCore(

private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, HubMessage? body, Type? typeHint)
{
var payload = httpMethod == HttpMethod.Post ? body : null;
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, body, typeHint, api.Token);
}

Expand All @@ -198,4 +208,4 @@ private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, S

private static Func<HttpResponseMessage, Task<bool>>? AsAsync(Func<HttpResponseMessage, bool>? syncFunc) =>
syncFunc == null ? null : (response => Task.FromResult(syncFunc(response)));
}
}
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ public Task<RestApiEndpoint> SendStreamCompletionAsync(string appName, string hu
return GenerateRestApiEndpointAsync(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/streams/{Uri.EscapeDataString(streamId)}/:complete");
}

public Task<RestApiEndpoint> SendClientInvocationAsync(string appName, string hubName, string connectionId, TimeSpan? lifetime = null)
{
return GenerateRestApiEndpointAsync(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke", lifetime);
}

private async Task<RestApiEndpoint> GenerateRestApiEndpointAsync(string appName, string hubName, string pathAfterHub, TimeSpan? lifetime = null, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
69 changes: 69 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
using System.Linq;
using System.Net;
using System.Net.Http;
#if NET7_0_OR_GREATER
using System.Text.Json;
using System.Text.Json.Nodes;
#endif
using System.Threading;
using System.Threading.Tasks;

using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using Microsoft.AspNetCore.SignalR.Protocol;
#endif
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Primitives;

Expand Down Expand Up @@ -351,6 +358,68 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
#nullable enable
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
// Validate input parameters
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}

// Get API endpoint and prepare for the request
var api = await _restApiProvider.SendClientInvocationAsync(_appName, _hubName, connectionId);
string? responseContent = null;
var isSuccessStatusCode = false;
// Send request and capture the response
await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
responseContent = await response.Content.ReadAsStringAsync();
isSuccessStatusCode = response.IsSuccessStatusCode;
return true;
},
cancellationToken: cancellationToken);

// Ensure we have a response
if (string.IsNullOrWhiteSpace(responseContent))
{
throw new HubException("Response content is null or empty");
}

var root = JsonNode.Parse(responseContent)
?? throw new HubException("Failed to parse response as JSON");

if (!isSuccessStatusCode)
{
var message = root["message"]?.GetValue<string>() ?? "Unknown error";
throw new HubException(message);
}

var resultNode = root["jsonObject"]?["result"]
?? throw new HubException("Result not found in JSON response");

return resultNode.Deserialize<T>()
?? throw new HubException("Failed to deserialize result");
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}
#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ private static IEnumerable<object[]> GetTestDataByContext((string appName, strin
yield return new object[] { RestApiProvider.GetUserGroupManagementEndpointAsync(context.appName, context.hubName, context.userId, context.groupName), $"{commonEndpoint}/users/{Uri.EscapeDataString(context.userId)}/groups/{Uri.EscapeDataString(context.groupName)}?{commonQueryString}" };
yield return new object[] { RestApiProvider.GetSendToConnectionEndpointAsync(context.appName, context.hubName, context.connectionId), $"{commonEndpoint}/connections/{Uri.EscapeDataString(context.connectionId)}/:send?{commonQueryString}" };
yield return new object[] { RestApiProvider.GetConnectionGroupManagementEndpointAsync(context.appName, context.hubName, context.connectionId, context.groupName), $"{commonEndpoint}/groups/{Uri.EscapeDataString(context.groupName)}/connections/{Uri.EscapeDataString(context.connectionId)}?{commonQueryString}" };
yield return new object[] { RestApiProvider.SendClientInvocationAsync(context.appName, context.hubName, context.connectionId), $"{commonEndpoint}/connections/{Uri.EscapeDataString(context.connectionId)}/:invoke?{commonQueryString}" };
}
}
Loading
Loading