Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -403,34 +403,7 @@ private static void AppendMethodSetup(Class @class, StringBuilder sb, Method met

sb.AppendLine();
sb.AppendLine("\t\t{");

if (@class is { ClassName: "HttpClient", ClassFullName: "System.Net.Http.HttpClient", } &&
method.Name.StartsWith("Send"))
{
sb.Append("\t\t\tif (setup is Mock<System.Net.Http.HttpClient> httpClientMock &&").AppendLine();
sb.Append(
"\t\t\t httpClientMock.ConstructorParameters[0] is IMockSubject<System.Net.Http.HttpMessageHandler> httpMessageHandlerMock &&")
.AppendLine();
sb.Append(
"\t\t\t httpMessageHandlerMock.Mock is IMockMethodSetup<System.Net.Http.HttpMessageHandler> httpMessageHandlerSetup)")
.AppendLine();
sb.Append("\t\t\t{").AppendLine();
AppendMethodSetupBody(sb, method, useParameters,
"\t\t\t\t",
method.GetUniqueNameString().Replace("System.Net.Http.HttpMessageInvoker",
"System.Net.Http.HttpMessageHandler"),
"httpMessageHandlerSetup");
sb.Append("\t\t\t}").AppendLine();
sb.Append("\t\t\telse").AppendLine();
sb.Append("\t\t\t{").AppendLine();
AppendMethodSetupBody(sb, method, useParameters, "\t\t\t\t");
sb.Append("\t\t\t}").AppendLine();
}
else
{
AppendMethodSetupBody(sb, method, useParameters);
}

AppendMethodSetupBody(sb, method, useParameters);
sb.AppendLine("\t\t}");
}

Expand Down
26 changes: 23 additions & 3 deletions Source/Mockolate.SourceGenerators/Sources/Sources.ForMock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace Mockolate.Generated;
sb.Append("\t\tvar ").Append(resultVarName).Append(" = _mock.Registrations.InvokeMethod(")
.Append(mockClass.Delegate.GetUniqueNameString());
}

foreach (MethodParameter p in mockClass.Delegate.Parameters)
{
sb.Append(", new NamedParameterValue(\"").Append(p.Name).Append("\", ").Append(p.RefKind switch
Expand Down Expand Up @@ -291,7 +291,7 @@ private static void AppendMockSubject_ImplementClass(StringBuilder sb, Class @cl
if (mockMethods?.All(m => !Method.EqualityComparer.Equals(method, m)) != false)
{
AppendMockSubject_ImplementClass_AddMethod(sb, method, className, mockClass is not null,
@class.IsInterface);
@class.IsInterface, @class);
}
}

Expand Down Expand Up @@ -607,7 +607,7 @@ property.IndexerParameters is not null
}

private static void AppendMockSubject_ImplementClass_AddMethod(StringBuilder sb, Method method, string className,
bool explicitInterfaceImplementation, bool isClassInterface)
bool explicitInterfaceImplementation, bool isClassInterface, Class @class)
{
sb.Append("\t/// <inheritdoc cref=\"").Append(method.ContainingType.EscapeForXmlDoc()).Append('.')
.Append(method.Name.EscapeForXmlDoc())
Expand Down Expand Up @@ -757,6 +757,26 @@ private static void AppendMockSubject_ImplementClass_AddMethod(StringBuilder sb,
if (method.ReturnType != Type.Void)
{
string baseResultVarName = Helpers.GetUniqueLocalVariableName("baseResult", method.Parameters);

if (method.Name.StartsWith("Send", StringComparison.Ordinal) &&
@class is { ClassName: "HttpClient", ClassFullName: "System.Net.Http.HttpClient", })
{
sb.Append("\t\t\t#if NETFRAMEWORK").AppendLine();
sb.Append("\t\t\t// Persist the HttpContent, because it gets automatically disposed on .NET Framework").AppendLine();
sb.Append("\t\t\tif (request.Content != null)").AppendLine();
sb.Append("\t\t\t{").AppendLine();
sb.Append(
"\t\t\t\tvar stream = request.Content.ReadAsStreamAsync().ConfigureAwait(false).GetAwaiter().GetResult();")
.AppendLine();
sb.Append("\t\t\t\tusing System.IO.MemoryStream ms = new();").AppendLine();
sb.Append("\t\t\t\tstream.CopyTo(ms);").AppendLine();
sb.Append("\t\t\t\tbyte[] bytes = ms.ToArray();").AppendLine();
sb.Append("\t\t\t\tstream.Position = 0L;").AppendLine();
sb.Append("\t\t\t\trequest.Properties.Add(\"Mockolate:HttpContent\", bytes);").AppendLine();
Comment thread
vbreuss marked this conversation as resolved.
sb.Append("\t\t\t}").AppendLine();
sb.Append("\t\t\t#endif").AppendLine();
}

sb.Append("\t\t\tvar ").Append(baseResultVarName).Append(" = base.").Append(method.Name).Append('(')
.Append(FormatMethodParametersWithRefKind(method.Parameters))
.Append(");").AppendLine();
Expand Down
10 changes: 8 additions & 2 deletions Source/Mockolate/Web/HttpClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ public void InvokeCallbacks(object? value)
}
}


private interface IHttpRequestMessageParameter
{
bool Matches(HttpRequestMessage value);
Expand All @@ -87,7 +86,14 @@ private sealed class HttpRequestMessageParameter<T>(
: IHttpRequestMessageParameter
{
public bool Matches(HttpRequestMessage value)
=> ((IParameter)parameter).Matches(valueSelector(value));
{
if (parameter is IHttpRequestMessagePropertyParameter<T> httpRequestMessageParameter)
{
return httpRequestMessageParameter.Matches(valueSelector(value), value);
}

return ((IParameter)parameter).Matches(valueSelector(value));
}

public void InvokeCallbacks(HttpRequestMessage value)
{
Expand Down
16 changes: 16 additions & 0 deletions Source/Mockolate/Web/IHttpRequestMessagePropertyParameter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using System.Net.Http;
using Mockolate.Parameters;

namespace Mockolate.Web;

/// <summary>
/// A parameter of type <typeparamref name="T" /> that also gets a <see cref="HttpRequestMessage" />.
/// </summary>
internal interface IHttpRequestMessagePropertyParameter<T> : IParameter<T>
{
/// <summary>
/// Matches the property of type <typeparamref name="T" /> while also considering the
/// <paramref name="requestMessage" />.
/// </summary>
bool Matches(T value, HttpRequestMessage? requestMessage);
}
118 changes: 76 additions & 42 deletions Source/Mockolate/Web/ItExtensions.HttpContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ public interface IHttpContentParameter : IParameter<HttpContent?>, IHttpHeaderPa
}

private sealed class HttpContentParameter
: IParameter, IStringContentBodyMatchingParameter, IFormDataContentParameter
: IParameter, IStringContentBodyMatchingParameter, IFormDataContentParameter,
IHttpRequestMessagePropertyParameter<HttpContent?>
{
private List<Action<HttpContent?>>? _callbacks;
private IContentMatcher? _contentMatcher;
Expand All @@ -137,9 +138,38 @@ public IFormDataContentParameter Exactly()
return this;
}

/// <inheritdoc cref="IHttpRequestMessagePropertyParameter{HttpContent}.Matches(HttpContent, HttpRequestMessage)" />
public bool Matches(HttpContent? value, HttpRequestMessage? requestMessage)
{
if (value is null)
{
return false;
}

if (_mediaType is not null &&
value.Headers.ContentType?.MediaType?.Equals(_mediaType, StringComparison.OrdinalIgnoreCase) != true)
{
return false;
}

if (_headers is not null &&
!_headers.Matches(value.Headers))
{
return false;
}

if (_contentMatcher is not null &&
!_contentMatcher.Matches(value, requestMessage))
{
return false;
}

return true;
}

/// <inheritdoc cref="IParameter.Matches(object?)" />
public bool Matches(object? value)
=> value is HttpContent typedValue && Matches(typedValue);
=> value is HttpContent typedValue && Matches(typedValue, null);

/// <inheritdoc cref="IParameter.InvokeCallbacks(object?)" />
public void InvokeCallbacks(object? value)
Expand Down Expand Up @@ -251,33 +281,7 @@ public IStringContentBodyParameter IgnoringCase()
return this;
}

/// <summary>
/// Checks whether the given <see cref="HttpContent" /> <paramref name="value" /> matches the expectations.
/// </summary>
private bool Matches(HttpContent value)
{
if (_mediaType is not null &&
value.Headers.ContentType?.MediaType?.Equals(_mediaType, StringComparison.OrdinalIgnoreCase) != true)
{
return false;
}

if (_headers is not null &&
!_headers.Matches(value.Headers))
{
return false;
}

if (_contentMatcher is not null &&
!_contentMatcher.Matches(value))
{
return false;
}

return true;
}

private static string GetStringFromHttpContent(HttpContent content)
private static string GetStringFromHttpContent(HttpContent content, HttpRequestMessage? message)
{
static Encoding GetEncodingFromCharset(string? charset)
{
Expand All @@ -300,18 +304,32 @@ static Encoding GetEncodingFromCharset(string? charset)
Encoding encoding = GetEncodingFromCharset(charset);
#if NET8_0_OR_GREATER
Stream stream = content.ReadAsStream();
long position = stream.Position;
using StreamReader reader = new(stream, encoding, leaveOpen: true);
string stringContent = reader.ReadToEnd();
stream.Position = position;
Comment thread
vbreuss marked this conversation as resolved.
#else
Stream stream = content.ReadAsStreamAsync().ConfigureAwait(false).GetAwaiter().GetResult();
using StreamReader reader = new(stream, encoding);
string stringContent;
if (message?.Properties.TryGetValue("Mockolate:HttpContent", out object value) == true
&& value is byte[] bytes)
Comment thread
vbreuss marked this conversation as resolved.
{
stringContent = encoding.GetString(bytes);
}
else
{
Stream stream = content.ReadAsStreamAsync().ConfigureAwait(false).GetAwaiter().GetResult();
long position = stream.Position;
using StreamReader reader = new(stream, encoding, true, 1024, true);
stringContent = reader.ReadToEnd();
stream.Position = position;
}
#endif
string stringContent = reader.ReadToEnd();
return stringContent;
}

private interface IContentMatcher
{
bool Matches(HttpContent content);
bool Matches(HttpContent content, HttpRequestMessage? message);
}

private sealed class StringMatcher : IContentMatcher
Expand All @@ -328,9 +346,9 @@ public StringMatcher(string value, bool isExact)
_bodyMatchType = isExact ? BodyMatchType.Exact : BodyMatchType.Wildcard;
}

public bool Matches(HttpContent content)
public bool Matches(HttpContent content, HttpRequestMessage? message)
{
string stringContent = GetStringFromHttpContent(content);
string stringContent = GetStringFromHttpContent(content, message);
switch (_bodyMatchType)
{
case BodyMatchType.Exact when
Expand Down Expand Up @@ -387,9 +405,9 @@ public PredicateStringMatcher(Func<string, bool> predicate)
_predicate = predicate;
}

public bool Matches(HttpContent content)
public bool Matches(HttpContent content, HttpRequestMessage? message)
{
string stringContent = GetStringFromHttpContent(content);
string stringContent = GetStringFromHttpContent(content, message);
return _predicate.Invoke(stringContent);
}
}
Expand All @@ -403,16 +421,32 @@ public BinaryMatcher(Func<byte[], bool> predicate)
_predicate = predicate;
}

public bool Matches(HttpContent content)
public bool Matches(HttpContent content, HttpRequestMessage? message)
{
#if NET8_0_OR_GREATER
Stream stream = content.ReadAsStream();
#else
Stream stream = content.ReadAsStreamAsync().ConfigureAwait(false).GetAwaiter().GetResult();
#endif
long position = stream.Position;
using MemoryStream ms = new();
stream.CopyTo(ms);
byte[] bytes = ms.ToArray();
stream.Position = position;
Comment thread
vbreuss marked this conversation as resolved.
#else
byte[] bytes;
if (message?.Properties.TryGetValue("Mockolate:HttpContent", out object value) == true
&& value is byte[] b)
{
bytes = b;
}
else
{
Stream stream = content.ReadAsStreamAsync().ConfigureAwait(false).GetAwaiter().GetResult();
long position = stream.Position;
using MemoryStream ms = new();
stream.CopyTo(ms);
bytes = ms.ToArray();
stream.Position = position;
}
#endif
return _predicate.Invoke(bytes);
}
}
Expand All @@ -439,7 +473,7 @@ public FormDataMatcher(string formDataParameters)
.Select(pair => (pair.Key, new HttpFormDataValue(pair.Value))));
}

public bool Matches(HttpContent content)
public bool Matches(HttpContent content, HttpRequestMessage? message)
{
List<(string Key, string Value)> formDataParameters = GetFormData(content).ToList();
return _isExactly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ public async Task StringUri_ShouldVerifyHttpContent(string mediaType, int expect
HttpClient httpClient = Mock.Create<HttpClient>();

await httpClient.PatchAsync("https://www.aweXpect.com",
new StringContent("", Encoding.UTF8, mediaType),
new StringContent("{}", Encoding.UTF8, mediaType),
CancellationToken.None);

await That(httpClient.VerifyMock.Invoked.PatchAsync(
It.IsAny<string?>(),
It.IsHttpContent("application/json")))
It.IsHttpContent("application/json").WithString("{}")))
.Exactly(expected);
}

Expand Down Expand Up @@ -111,12 +111,12 @@ public async Task Uri_ShouldVerifyHttpContent(string mediaType, int expected)
HttpClient httpClient = Mock.Create<HttpClient>();

await httpClient.PatchAsync("https://www.aweXpect.com",
new StringContent("", Encoding.UTF8, mediaType),
new StringContent("{}", Encoding.UTF8, mediaType),
CancellationToken.None);

await That(httpClient.VerifyMock.Invoked.PatchAsync(
It.IsUri("*aweXpect.com*"),
It.IsHttpContent("application/json")))
It.IsHttpContent("application/json").WithString("{}")))
.Exactly(expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ public async Task StringUri_ShouldVerifyHttpContent(string mediaType, int expect
HttpClient httpClient = Mock.Create<HttpClient>();

await httpClient.PostAsync("https://www.aweXpect.com",
new StringContent("", Encoding.UTF8, mediaType),
new StringContent("{}", Encoding.UTF8, mediaType),
CancellationToken.None);

await That(httpClient.VerifyMock.Invoked.PostAsync(
It.IsAny<string?>(),
It.IsHttpContent("application/json")))
It.IsHttpContent("application/json").WithString("{}")))
.Exactly(expected);
}

Expand Down Expand Up @@ -109,12 +109,12 @@ public async Task Uri_ShouldVerifyHttpContent(string mediaType, int expected)
HttpClient httpClient = Mock.Create<HttpClient>();

await httpClient.PostAsync("https://www.aweXpect.com",
new StringContent("", Encoding.UTF8, mediaType),
new StringContent("{}", Encoding.UTF8, mediaType),
CancellationToken.None);

await That(httpClient.VerifyMock.Invoked.PostAsync(
It.IsUri("*aweXpect.com*"),
It.IsHttpContent("application/json")))
It.IsHttpContent("application/json").WithString("{}")))
.Exactly(expected);
}

Expand Down
Loading
Loading