Skip to content

Commit c685ebd

Browse files
Client builder extension customization fix (#54000)
1 parent da29d9e commit c685ebd

File tree

4 files changed

+145
-18
lines changed

4 files changed

+145
-18
lines changed

eng/packages/http-client-csharp/generator/Azure.Generator/src/Providers/ClientBuilderExtensionsDefinition.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ protected override MethodProvider[] BuildMethods()
7676
continue;
7777
}
7878

79-
// only add overloads for the full constructors that include the client options parameter
80-
if (constructor.Signature.Parameters.LastOrDefault()?.Type.Equals(client.ClientOptionsParameter.Type) != true)
79+
// Only add overloads for the full constructors that include the client options parameter
80+
// Check that the name of the last parameter matches the client options parameter as the namespace will not be resolved for
81+
// customized constructors. This is safe as we don't allow multiple types types with the same name in an Azure library.
82+
if (constructor.Signature.Parameters.LastOrDefault()?.Type.Name.Equals(client.ClientOptionsParameter.Type.Name) != true)
8183
{
8284
continue;
8385
}

eng/packages/http-client-csharp/generator/Azure.Generator/test/Providers/ClientBuilderExtensionsDefinitions/ClientBuilderExtensionsTests.cs

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
1010
using Microsoft.TypeSpec.Generator.Input;
1111
using Microsoft.TypeSpec.Generator.Primitives;
12+
using Microsoft.TypeSpec.Generator.Providers;
1213
using NUnit.Framework;
14+
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
1315

1416
namespace Azure.Generator.Tests.Providers.ClientBuilderExtensionsDefinitions
1517
{
@@ -18,7 +20,7 @@ public class ClientBuilderExtensionsTests
1820
[Test]
1921
public void AddsClientExtensionForApiKeyAuth()
2022
{
21-
var client = InputFactory.Client("TestClient", "Samples", "");
23+
var client = InputFactory.Client("TestClient", "Samples", "");
2224
var plugin = MockHelpers.LoadMockGenerator(
2325
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
2426
clients: () => [client]);
@@ -35,9 +37,9 @@ public void AddsClientExtensionForApiKeyAuth()
3537
[Test]
3638
public void AddsClientExtensionForOAuth()
3739
{
38-
var client = InputFactory.Client("TestClient", "Samples", "");
40+
var client = InputFactory.Client("TestClient", "Samples", "");
3941
var plugin = MockHelpers.LoadMockGenerator(
40-
oauth2Auth: ()=> new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
42+
oauth2Auth: () => new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
4143
clients: () => [client]);
4244

4345
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
@@ -52,10 +54,10 @@ public void AddsClientExtensionForOAuth()
5254
[Test]
5355
public void AddsClientExtensionForEachAuthMethod()
5456
{
55-
var client = InputFactory.Client("TestClient", "Samples", "");
57+
var client = InputFactory.Client("TestClient", "Samples", "");
5658
var plugin = MockHelpers.LoadMockGenerator(
5759
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
58-
oauth2Auth: ()=> new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
60+
oauth2Auth: () => new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
5961
clients: () => [client]);
6062

6163
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
@@ -70,11 +72,11 @@ public void AddsClientExtensionForEachAuthMethod()
7072
[Test]
7173
public void AddsClientExtensionForEachAuthMethodMultipleClients()
7274
{
73-
var client1 = InputFactory.Client("TestClient", "Samples", "");
74-
var client2 = InputFactory.Client("TestClient2", "Samples", "");
75+
var client1 = InputFactory.Client("TestClient", "Samples", "");
76+
var client2 = InputFactory.Client("TestClient2", "Samples", "");
7577
var plugin = MockHelpers.LoadMockGenerator(
7678
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
77-
oauth2Auth: ()=> new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
79+
oauth2Auth: () => new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
7880
clients: () => [client1, client2]);
7981

8082
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
@@ -86,17 +88,39 @@ public void AddsClientExtensionForEachAuthMethodMultipleClients()
8688
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
8789
}
8890

91+
[Test]
92+
public void AddsClientExtensionForCustomConstructor()
93+
{
94+
var inputClient = InputFactory.Client("TestClient", "Samples", "");
95+
var plugin = MockHelpers.LoadMockGenerator(
96+
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
97+
clients: () => [inputClient]);
98+
99+
var client = plugin.Object.OutputLibrary.TypeProviders
100+
.OfType<ClientProvider>().Single();
101+
Assert.IsNotNull(client);
102+
MockHelpers.SetCustomCodeView(client, new TestCustomCodeView(client));
103+
104+
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
105+
.OfType<ClientBuilderExtensionsDefinition>().SingleOrDefault();
106+
107+
Assert.IsNotNull(builderExtensions);
108+
var writer = new TypeProviderWriter(builderExtensions!);
109+
var file = writer.Write();
110+
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
111+
}
112+
89113
[Test]
90114
public void DoesNotAddExtensionMethodsClassIfOnlyInternalClients()
91115
{
92116
var client = InputFactory.Client("TestClient", "Samples", "");
93117
var plugin = MockHelpers.LoadMockGenerator(
94118
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
95-
oauth2Auth: ()=> new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
119+
oauth2Auth: () => new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
96120
clients: () => [client],
97121
createClientCore: inputClient =>
98122
{
99-
var provider = new ClientProvider(inputClient);
123+
var provider = new ClientProvider(inputClient);
100124
provider.Update(modifiers: TypeSignatureModifiers.Internal | TypeSignatureModifiers.Class);
101125
return provider;
102126
});
@@ -114,15 +138,16 @@ public void DoesNotAddExtensionMethodsForInternalClients()
114138
var client2 = InputFactory.Client("TestClient2", "Samples", "");
115139
var plugin = MockHelpers.LoadMockGenerator(
116140
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
117-
oauth2Auth: ()=> new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
141+
oauth2Auth: () => new InputOAuth2Auth([new InputOAuth2Flow(["mock"], null, null, null)]),
118142
clients: () => [client1, client2],
119143
createClientCore: inputClient =>
120144
{
121-
var provider = new ClientProvider(inputClient);
145+
var provider = new ClientProvider(inputClient);
122146
if (inputClient.Name == "TestClient1")
123147
{
124148
provider.Update(modifiers: TypeSignatureModifiers.Internal | TypeSignatureModifiers.Class);
125149
}
150+
126151
return provider;
127152
});
128153

@@ -136,5 +161,52 @@ public void DoesNotAddExtensionMethodsForInternalClients()
136161
Assert.IsTrue(method.Signature.Name.EndsWith("TestClient2", StringComparison.Ordinal));
137162
}
138163
}
164+
165+
private class TestCustomCodeView : TypeProvider
166+
{
167+
private readonly ClientProvider _clientProvider;
168+
169+
public TestCustomCodeView(ClientProvider clientProvider)
170+
{
171+
_clientProvider = clientProvider;
172+
}
173+
174+
protected override string BuildRelativeFilePath() => _clientProvider.RelativeFilePath;
175+
176+
protected override string BuildName() => _clientProvider.Name;
177+
178+
protected override ConstructorProvider[] BuildConstructors()
179+
=>
180+
[
181+
new ConstructorProvider(
182+
new ConstructorSignature(
183+
Type,
184+
$"",
185+
MethodSignatureModifiers.Public,
186+
[
187+
new ParameterProvider("endpoint", $"", typeof(string)),
188+
new ParameterProvider("options", $"", new TestClientOptionsProvider(_clientProvider.ClientOptions!).Type),
189+
]),
190+
ThrowExpression(Null),
191+
this)
192+
];
193+
}
194+
195+
private class TestClientOptionsProvider : TypeProvider
196+
{
197+
private readonly ClientOptionsProvider _options;
198+
199+
public TestClientOptionsProvider(ClientOptionsProvider clientOptions)
200+
{
201+
_options = clientOptions;
202+
}
203+
204+
protected override string BuildRelativeFilePath() => _options.RelativeFilePath;
205+
206+
protected override string BuildName() => _options.Name;
207+
208+
// simulate empty namespace
209+
protected override string BuildNamespace() => "";
210+
}
139211
}
140212
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// <auto-generated/>
5+
6+
#nullable disable
7+
8+
using System;
9+
using System.Diagnostics.CodeAnalysis;
10+
using Azure;
11+
using Azure.Core.Extensions;
12+
using Samples;
13+
14+
namespace Microsoft.Extensions.Azure
15+
{
16+
/// <summary> Extension methods to add clients to <see cref="global::Azure.Core.Extensions.IAzureClientBuilder{TClient,TOptions}"/>. </summary>
17+
public static partial class SamplesClientBuilderExtensions
18+
{
19+
/// <summary> Registers a <see cref="TestClient"/> client with the specified <see cref="global::Azure.Core.Extensions.IAzureClientBuilder{TClient,TOptions}"/>. </summary>
20+
/// <param name="builder"> The builder to register with. </param>
21+
/// <param name="endpoint"> Service endpoint. </param>
22+
/// <param name="credential"> A credential used to authenticate to the service. </param>
23+
/// <exception cref="global::System.ArgumentNullException"> <paramref name="endpoint"/> or <paramref name="credential"/> is null. </exception>
24+
public static global::Azure.Core.Extensions.IAzureClientBuilder<global::Samples.TestClient, global::Samples.TestClientOptions> AddTestClient<TBuilder>(this TBuilder builder, global::System.Uri endpoint, global::Azure.AzureKeyCredential credential)
25+
where TBuilder : global::Azure.Core.Extensions.IAzureClientFactoryBuilder
26+
{
27+
global::Samples.Argument.AssertNotNull(endpoint, nameof(endpoint));
28+
global::Samples.Argument.AssertNotNull(credential, nameof(credential));
29+
30+
return builder.RegisterClientFactory<global::Samples.TestClient, global::Samples.TestClientOptions>(options => new global::Samples.TestClient(endpoint, credential, options));
31+
}
32+
33+
/// <summary> Registers a <see cref="TestClient"/> client with the specified <see cref="global::Azure.Core.Extensions.IAzureClientBuilder{TClient,TOptions}"/>. </summary>
34+
/// <param name="builder"> The builder to register with. </param>
35+
/// <param name="endpoint"></param>
36+
public static global::Azure.Core.Extensions.IAzureClientBuilder<global::Samples.TestClient, global::Samples.TestClientOptions> AddTestClient<TBuilder>(this TBuilder builder, string endpoint)
37+
where TBuilder : global::Azure.Core.Extensions.IAzureClientFactoryBuilder
38+
{
39+
return builder.RegisterClientFactory<global::Samples.TestClient, global::Samples.TestClientOptions>(options => new global::Samples.TestClient(endpoint, options));
40+
}
41+
42+
/// <summary> Registers a <see cref="TestClient"/> client with the specified <see cref="global::Azure.Core.Extensions.IAzureClientBuilder{TClient,TOptions}"/>. </summary>
43+
/// <param name="builder"> The builder to register with. </param>
44+
/// <param name="configuration"> The configuration to use for the client. </param>
45+
[global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("Requires unreferenced code until we opt into EnableConfigurationBindingGenerator.")]
46+
[global::System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute("Requires unreferenced code until we opt into EnableConfigurationBindingGenerator.")]
47+
public static global::Azure.Core.Extensions.IAzureClientBuilder<global::Samples.TestClient, global::Samples.TestClientOptions> AddTestClient<TBuilder, TConfiguration>(this TBuilder builder, TConfiguration configuration)
48+
where TBuilder : global::Azure.Core.Extensions.IAzureClientFactoryBuilderWithConfiguration<TConfiguration>
49+
{
50+
return builder.RegisterClientFactory<global::Samples.TestClient, global::Samples.TestClientOptions>(configuration);
51+
}
52+
}
53+
}

eng/packages/http-client-csharp/generator/Azure.Generator/test/TestHelpers/MockHelpers.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ public static Mock<AzureClientGenerator> LoadMockGenerator(
113113
return mockPluginInstance;
114114
}
115115

116-
public static void SetCustomCodeView(ModelProvider modelProvider, TypeProvider customCodeTypeProvider)
116+
public static void SetCustomCodeView(TypeProvider typeProvider, TypeProvider customCodeTypeProvider)
117117
{
118-
modelProvider.GetType().BaseType!.GetField(
118+
typeProvider.GetType().BaseType!.GetField(
119119
"_customCodeView",
120-
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?
121-
.SetValue(modelProvider, new Lazy<TypeProvider>(() => customCodeTypeProvider));
120+
BindingFlags.NonPublic | BindingFlags.Instance)?
121+
.SetValue(typeProvider, new Lazy<TypeProvider>(() => customCodeTypeProvider));
122122
}
123123
}
124124
}

0 commit comments

Comments
 (0)