Skip to content

Commit

Permalink
Generate Allow header if a framework is called with the wrong HTTP …
Browse files Browse the repository at this point in the history
…method (#590)
  • Loading branch information
Kaliumhexacyanoferrat authored Dec 12, 2024
1 parent 8dbdc25 commit ea98e9b
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 19 deletions.
24 changes: 22 additions & 2 deletions API/Content/ProviderException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,38 @@ public class ProviderException : Exception
/// </summary>
public ResponseStatus Status { get; }

/// <summary>
/// Modifications to be applied to the generated HTTP response.
/// </summary>
public Action<IResponseBuilder>? Modifications { get; }

#endregion

#region Initialization

public ProviderException(ResponseStatus status, string message) : base(message)
/// <summary>
/// Raises an exception that allows the server to derive a HTTP response status from.
/// </summary>
/// <param name="status">The status of the HTTP response to be set</param>
/// <param name="message">The error message to return to the client</param>
/// <param name="modifications">The modifications to be applied to the generated response</param>
public ProviderException(ResponseStatus status, string message, Action<IResponseBuilder>? modifications = null) : base(message)
{
Status = status;
Modifications = modifications;
}

public ProviderException(ResponseStatus status, string message, Exception inner) : base(message, inner)
/// <summary>
/// Raises an exception that allows the server to derive a HTTP response status from.
/// </summary>
/// <param name="status">The status of the HTTP response to be set</param>
/// <param name="message">The error message to return to the client</param>
/// <param name="inner">The original exception that caused this error</param>
/// <param name="modifications">The modifications to be applied to the generated response</param>
public ProviderException(ResponseStatus status, string message, Exception inner, Action<IResponseBuilder>? modifications = null) : base(message, inner)
{
Status = status;
Modifications = modifications;
}

#endregion
Expand Down
3 changes: 1 addition & 2 deletions API/Protocol/IResponseModification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
/// This can be useful if you would like to add behavior that the
/// original handler (such as a page renderer) does not provide.
/// For example, as the page handlers implement this interface,
/// you can add an additional header to the response being generated
/// for a page.
/// you can add a header to the response being generated for a page.
/// </remarks>
/// <typeparam name="TBuilder">The type of builder used as a return value</typeparam>
public interface IResponseModification<out TBuilder>
Expand Down
11 changes: 11 additions & 0 deletions Modules/Basics/CoreExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ public static bool HasType(this IRequest request, params RequestMethod[] methods
/// <param name="contentType">The content type of this response</param>
public static IResponseBuilder Type(this IResponseBuilder builder, string contentType) => builder.Type(FlexibleContentType.Parse(contentType));

/// <summary>
/// Applies the given modifications to the response.
/// </summary>
/// <param name="builder">The response to be modified</param>
/// <param name="modifications">The modifications to be applied</param>
public static IResponseBuilder Apply(this IResponseBuilder builder, Action<IResponseBuilder>? modifications)
{
modifications?.Invoke(builder);
return builder;
}

#endregion

#region Content types
Expand Down
2 changes: 2 additions & 0 deletions Modules/ErrorHandling/HtmlErrorMapper.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using GenHTTP.Api.Content;
using GenHTTP.Api.Protocol;
using GenHTTP.Modules.Basics;
using GenHTTP.Modules.Pages;

namespace GenHTTP.Modules.ErrorHandling;
Expand All @@ -19,6 +20,7 @@ public class HtmlErrorMapper : IErrorMapper<Exception>

return request.GetPage(page)
.Status(e.Status)
.Apply(e.Modifications)
.Build();
}
else
Expand Down
11 changes: 6 additions & 5 deletions Modules/ErrorHandling/StructuredErrorMapper.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using GenHTTP.Api.Content;
using GenHTTP.Api.Protocol;
using GenHTTP.Modules.Basics;
using GenHTTP.Modules.Conversion;
using GenHTTP.Modules.Conversion.Serializers;
using GenHTTP.Modules.Conversion.Serializers.Json;
Expand Down Expand Up @@ -45,32 +46,32 @@ public record ErrorModel(ResponseStatus Status, string Message, string? StackTra
{
var model = new ErrorModel(e.Status, error.Message, stackTrace);

return await RenderAsync(request, model);
return (await RenderAsync(request, model)).Apply(e.Modifications).Build();
}
else
{
var model = new ErrorModel(ResponseStatus.InternalServerError, error.Message, stackTrace);

return await RenderAsync(request, model);
return (await RenderAsync(request, model)).Build();
}
}

public async ValueTask<IResponse?> GetNotFound(IRequest request, IHandler handler)
{
var model = new ErrorModel(ResponseStatus.NotFound, "The requested resource does not exist on this server");

return await RenderAsync(request, model);
return (await RenderAsync(request, model)).Build();
}

private async ValueTask<IResponse> RenderAsync(IRequest request, ErrorModel model)
private async ValueTask<IResponseBuilder> RenderAsync(IRequest request, ErrorModel model)
{
var format = Registry.GetSerialization(request) ?? new JsonFormat();

var response = await format.SerializeAsync(request, model);

response.Status(model.Status);

return response.Build();
return response;
}

#endregion
Expand Down
2 changes: 1 addition & 1 deletion Modules/IO/Providers/DownloadProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public DownloadProvider(IResource resourceProvider, string? fileName, FlexibleCo

if (!request.HasType(RequestMethod.Get, RequestMethod.Head))
{
throw new ProviderException(ResponseStatus.MethodNotAllowed, "Only GET requests are allowed by this handler");
throw new ProviderException(ResponseStatus.MethodNotAllowed, "Only GET requests are allowed by this handler", (b) => b.Header("Allow", "GET"));
}

var response = request.Respond()
Expand Down
20 changes: 13 additions & 7 deletions Modules/Reflection/MethodCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public MethodCollection(IEnumerable<MethodHandler> methods)

public ValueTask<IResponse?> HandleAsync(IRequest request)
{
var methods = FindProviders(request.Target.GetRemaining().ToString(), request.Method, out var foundOthers);
var methods = FindProviders(request.Target.GetRemaining().ToString(), request.Method, out var others);

if (methods.Count == 1)
{
Expand All @@ -43,9 +43,15 @@ public MethodCollection(IEnumerable<MethodHandler> methods)

throw new ProviderException(ResponseStatus.BadRequest, $"There are multiple methods matching '{request.Target.Path}'");
}
if (foundOthers)

if (others.Count > 0)
{
throw new ProviderException(ResponseStatus.MethodNotAllowed, "There is no method of a matching request type");
throw new ProviderException(ResponseStatus.MethodNotAllowed, "There is no method of a matching request type", AddAllowHeader);

void AddAllowHeader(IResponseBuilder b)
{
b.Header("Allow", string.Join(", ", others.Select(o => o.RawMethod.ToUpper())));
}
}

return new ValueTask<IResponse?>();
Expand All @@ -59,9 +65,9 @@ public async ValueTask PrepareAsync()
}
}

private List<MethodHandler> FindProviders(string path, FlexibleRequestMethod requestedMethod, out bool foundOthers)
private List<MethodHandler> FindProviders(string path, FlexibleRequestMethod requestedMethod, out HashSet<FlexibleRequestMethod> otherMethods)
{
foundOthers = false;
otherMethods = new HashSet<FlexibleRequestMethod>();

var result = new List<MethodHandler>(2);

Expand All @@ -75,7 +81,7 @@ private List<MethodHandler> FindProviders(string path, FlexibleRequestMethod req
}
else
{
foundOthers = true;
otherMethods.UnionWith(method.Configuration.SupportedMethods);
}
}
else
Expand All @@ -88,7 +94,7 @@ private List<MethodHandler> FindProviders(string path, FlexibleRequestMethod req
}
else
{
foundOthers = true;
otherMethods.UnionWith(method.Configuration.SupportedMethods);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion Modules/ServerSentEvents/Handler/EventSourceHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public EventSourceHandler(Func<IRequest, string?, ValueTask<bool>>? inspector, F
{
if (request.Method.KnownMethod != RequestMethod.Get)
{
throw new ProviderException(ResponseStatus.MethodNotAllowed, "Server Sent Events require a GET request to establish a connection");
throw new ProviderException(ResponseStatus.MethodNotAllowed, "Server Sent Events require a GET request to establish a connection", (b) => b.Header("Allow", "GET"));
}

request.Headers.TryGetValue("Last-Event-ID", out var lastId);
Expand Down
75 changes: 75 additions & 0 deletions Testing/Acceptance/Modules/ErrorHandling/HtmlErrorMapperTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System.Net;
using GenHTTP.Api.Content;
using GenHTTP.Api.Protocol;
using GenHTTP.Modules.ErrorHandling;
using GenHTTP.Modules.Functional;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace GenHTTP.Testing.Acceptance.Modules.ErrorHandling;

[TestClass]
public class HtmlErrorMapperTest
{

[TestMethod]
[MultiEngineTest]
public async Task TestNotFound(TestEngine engine)
{
await using var host = await TestHost.RunAsync(Inline.Create().Add(ErrorHandler.Html()), engine: engine);

using var response = await host.GetResponseAsync();

await response.AssertStatusAsync(HttpStatusCode.NotFound);
}

[TestMethod]
[MultiEngineTest]
public async Task TestGeneralError(TestEngine engine)
{
var handler = Inline.Create()
.Get(() => DoThrow(new Exception("Oops")))
.Add(ErrorHandler.Html());

await using var host = await TestHost.RunAsync(handler, engine: engine);

using var response = await host.GetResponseAsync();

await response.AssertStatusAsync(HttpStatusCode.InternalServerError);
}

[TestMethod]
[MultiEngineTest]
public async Task TestProviderError(TestEngine engine)
{
var handler = Inline.Create()
.Get(() => DoThrow(new ProviderException(ResponseStatus.Locked, "Locked up!")))
.Add(ErrorHandler.Html());

await using var host = await TestHost.RunAsync(handler, engine: engine);

using var response = await host.GetResponseAsync();

await response.AssertStatusAsync(HttpStatusCode.Locked);
}

[TestMethod]
[MultiEngineTest]
public async Task TestNoTraceInProduction(TestEngine engine)
{
var handler = Inline.Create()
.Get(() => DoThrow(new Exception("Oops")))
.Add(ErrorHandler.Html());

await using var host = await TestHost.RunAsync(handler, development: false, engine: engine);

using var response = await host.GetResponseAsync();

await response.AssertStatusAsync(HttpStatusCode.InternalServerError);
}

private static void DoThrow(Exception e)
{
throw e;
}

}
2 changes: 2 additions & 0 deletions Testing/Acceptance/Modules/IO/DownloadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public async Task DownloadsCannotBeModified(TestEngine engine)
using var response = await runner.GetResponseAsync(request);

await response.AssertStatusAsync(HttpStatusCode.MethodNotAllowed);

Assert.AreEqual("GET", response.GetContentHeader("Allow"));
}

[TestMethod]
Expand Down
2 changes: 2 additions & 0 deletions Testing/Acceptance/Modules/ServerSentEvents/ProtocolTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public async Task TestGetOnly(TestEngine engine)
using var response = await host.GetResponseAsync(request);

await response.AssertStatusAsync(HttpStatusCode.MethodNotAllowed);

Assert.AreEqual("GET", response.GetContentHeader("Allow"));
}

[TestMethod]
Expand Down
2 changes: 1 addition & 1 deletion Testing/Acceptance/Modules/Webservices/WebserviceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public async Task TestUnsupportedDownloadEnforcesDefault(TestEngine engine)
[MultiEngineTest]
public async Task TestWrongMethod(TestEngine engine)
{
await WithResponse(engine, "entity", HttpMethod.Put, "123", null, null, async r => { await r.AssertStatusAsync(HttpStatusCode.MethodNotAllowed); });
await WithResponse(engine, "entity", HttpMethod.Put, "123", null, null, async r => { await r.AssertStatusAsync(HttpStatusCode.MethodNotAllowed); Assert.AreEqual("POST", r.GetContentHeader("Allow")); });
}

[TestMethod]
Expand Down

0 comments on commit ea98e9b

Please sign in to comment.