Skip to content

Commit e98b3bb

Browse files
Fix websocket broken when connecting via SSL/TLS (#611)
* Rework websocket connection to streams instead of sockets
1 parent 7c2f475 commit e98b3bb

File tree

8 files changed

+140
-16
lines changed

8 files changed

+140
-16
lines changed

API/Protocol/IRequest.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ public interface IRequest : IDisposable
3131
IResponseBuilder Respond();
3232

3333
/// <summary>
34-
/// Upgrades the connection of the client, causing the underlying socket
34+
/// Upgrades the connection of the client, causing the underlying socket and streams
3535
/// to be exposed and to be used by another protocol, such as a websocket handler.
3636
/// </summary>
3737
/// <returns>The upgrade information to use</returns>
3838
/// <remarks>
39-
/// After upgrading a connection, the server will surrender this socket to your logic
39+
/// After upgrading a connection, the server will surrender this socket and stream to your logic
4040
/// to no further interaction is done by the framework with the client and is up to
4141
/// your logic.
4242
/// </remarks>

API/Protocol/UpgradeInfo.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ namespace GenHTTP.Api.Protocol;
66
/// Returned when upgrading a connection.
77
/// </summary>
88
/// <param name="Socket">The raw socket the current client is connected to</param>
9+
/// <param name="Stream">The underlying network stream used for the connection (already authenticated in case of TLS)</param>
910
/// <param name="Response">The response to return so that the server will ignore the connection further on</param>
10-
public record UpgradeInfo(Socket Socket, IResponse Response);
11+
public record UpgradeInfo(Socket Socket, Stream Stream, IResponse Response);

Engine/Internal/Protocol/ClientHandler.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ private async PooledValueTask<ConnectionStatus> HandlePipe(PipeReader reader)
153153

154154
private async PooledValueTask<ConnectionStatus> HandleRequest(RequestBuilder builder, bool dataRemaining)
155155
{
156-
using var request = builder.Connection(Server, Connection, EndPoint, Connection.GetAddress(), ClientCertificate).Build();
156+
using var request = builder.Connection(Server, Connection, Stream, EndPoint, Connection.GetAddress(), ClientCertificate).Build();
157157

158158
KeepAlive ??= request["Connection"]?.Equals("Keep-Alive", StringComparison.InvariantCultureIgnoreCase) ?? request.ProtocolType == HttpProtocol.Http11;
159159

Engine/Internal/Protocol/Request.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace GenHTTP.Engine.Internal.Protocol;
1212
internal sealed class Request : IRequest
1313
{
1414
private readonly Socket _Socket;
15+
private readonly Stream _Stream;
1516

1617
private FlexibleContentType? _ContentType;
1718
private ICookieCollection? _Cookies;
@@ -24,11 +25,12 @@ internal sealed class Request : IRequest
2425

2526
#region Initialization
2627

27-
internal Request(IServer server, Socket socket, IEndPoint endPoint, IClientConnection client, IClientConnection localClient, HttpProtocol protocol, FlexibleRequestMethod method,
28+
internal Request(IServer server, Socket socket, Stream stream, IEndPoint endPoint, IClientConnection client, IClientConnection localClient, HttpProtocol protocol, FlexibleRequestMethod method,
2829
RoutingTarget target, IHeaderCollection headers, ICookieCollection? cookies, IForwardingCollection? forwardings,
2930
IRequestQuery? query, Stream? content)
3031
{
3132
_Socket = socket;
33+
_Stream = stream;
3234

3335
Client = client;
3436
LocalClient = localClient;
@@ -55,7 +57,7 @@ internal Request(IServer server, Socket socket, IEndPoint endPoint, IClientConne
5557

5658
public IResponseBuilder Respond() => new ResponseBuilder().Status(ResponseStatus.Ok);
5759

58-
public UpgradeInfo Upgrade() => new(_Socket, new Response() { Upgraded = true });
60+
public UpgradeInfo Upgrade() => new(_Socket, _Stream, new Response { Upgraded = true });
5961

6062
#endregion
6163

Engine/Internal/Protocol/RequestBuilder.cs

+11-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ namespace GenHTTP.Engine.Internal.Protocol;
1212
internal sealed class RequestBuilder : IBuilder<IRequest>
1313
{
1414
private Socket? _Socket;
15+
private Stream? _Stream;
16+
1517
private IPAddress? _Address;
1618
private X509Certificate? _ClientCertificate;
1719

@@ -57,9 +59,11 @@ private ForwardingCollection Forwardings
5759

5860
#region Functionality
5961

60-
public RequestBuilder Connection(IServer server, Socket socket, IEndPoint endPoint, IPAddress? address, X509Certificate? clientCertificate)
62+
public RequestBuilder Connection(IServer server, Socket socket, Stream stream, IEndPoint endPoint, IPAddress? address, X509Certificate? clientCertificate)
6163
{
6264
_Socket = socket;
65+
_Stream = stream;
66+
6367
_Server = server;
6468
_Address = address;
6569
_EndPoint = endPoint;
@@ -130,6 +134,11 @@ public IRequest Build()
130134
throw new BuilderMissingPropertyException("Socket");
131135
}
132136

137+
if (_Stream == null)
138+
{
139+
throw new BuilderMissingPropertyException("Stream");
140+
}
141+
133142
if (_EndPoint is null)
134143
{
135144
throw new BuilderMissingPropertyException("EndPoint");
@@ -171,7 +180,7 @@ public IRequest Build()
171180

172181
var client = Forwardings.DetermineClient(_ClientCertificate) ?? localClient;
173182

174-
return new Request(_Server, _Socket, _EndPoint, client, localClient, (HttpProtocol)_Protocol, _RequestMethod,
183+
return new Request(_Server, _Socket, _Stream, _EndPoint, client, localClient, (HttpProtocol)_Protocol, _RequestMethod,
175184
_Target, Headers, _Cookies, _Forwardings, _Query, _Content);
176185
}
177186
catch (Exception)

Modules/Websockets/Handler/WebsocketConnection.cs

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using System.Net.Sockets;
2-
3-
using Fleck;
1+
using Fleck;
42

53
using GenHTTP.Api.Protocol;
64

@@ -49,7 +47,7 @@ public bool IsAvailable
4947

5048
#region Initialization
5149

52-
public WebsocketConnection(Socket socket, IRequest request, List<string> supportedProtocols,
50+
public WebsocketConnection(ISocket socket, IRequest request, List<string> supportedProtocols,
5351
Action<IWebsocketConnection>? onOpen,
5452
Action<IWebsocketConnection>? onClose,
5553
Action<IWebsocketConnection, string>? onMessage,
@@ -58,11 +56,11 @@ public WebsocketConnection(Socket socket, IRequest request, List<string> support
5856
Action<IWebsocketConnection, byte[]>? onPong,
5957
Action<IWebsocketConnection, Exception>? onError)
6058
{
61-
Socket = new SocketWrapper(socket);
59+
Socket = socket;
6260
Request = request;
6361

6462
SupportedProtocols = supportedProtocols;
65-
63+
6664
OnOpen = (onOpen != null) ? () => onOpen(this) : () => { };
6765
OnClose = (onClose != null) ? () => onClose(this) : () => { };
6866
OnMessage = (onMessage != null) ? (x) => onMessage(this, x) : x => { };
@@ -152,7 +150,7 @@ public void Close(int code)
152150
}
153151

154152
var bytes = Handler.FrameClose(code);
155-
153+
156154
if (bytes.Length == 0)
157155
CloseSocket();
158156
else

Modules/Websockets/Handler/WebsocketHandler.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ public WebsocketHandler(List<string> supportedProtocols,
6565

6666
var upgrade = request.Upgrade();
6767

68-
var connection = new WebsocketConnection(upgrade.Socket, request, SupportedProtocols, OnOpen, OnClose, OnMessage, OnBinary, OnPing, OnPong, OnError);
68+
var socket = new WrappedSocket(upgrade);
69+
70+
var connection = new WebsocketConnection(socket, request, SupportedProtocols, OnOpen, OnClose, OnMessage, OnBinary, OnPing, OnPong, OnError);
6971

7072
connection.Start();
7173

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
using System.Net;
2+
using System.Net.Sockets;
3+
using System.Security.Authentication;
4+
using System.Security.Cryptography.X509Certificates;
5+
using Fleck;
6+
using GenHTTP.Api.Protocol;
7+
8+
namespace GenHTTP.Modules.Websockets.Handler;
9+
10+
public sealed class WrappedSocket : ISocket
11+
{
12+
13+
#region Get-/Setters
14+
15+
public Socket Socket { get; }
16+
17+
public Stream Stream { get; }
18+
19+
public CancellationTokenSource TokenSource { get; }
20+
21+
public bool Connected => Socket.Connected;
22+
23+
public bool NoDelay
24+
{
25+
get => Socket.NoDelay;
26+
set => Socket.NoDelay = value;
27+
}
28+
29+
public string RemoteIpAddress => (Socket.RemoteEndPoint as IPEndPoint)?.Address.ToString() ?? string.Empty;
30+
31+
public int RemotePort => (Socket.RemoteEndPoint as IPEndPoint)?.Port ?? -1;
32+
33+
#endregion
34+
35+
#region Initialization
36+
37+
public WrappedSocket(UpgradeInfo upgradeInfo)
38+
{
39+
Socket = upgradeInfo.Socket;
40+
Stream = upgradeInfo.Stream;
41+
42+
TokenSource = new();
43+
}
44+
45+
#endregion
46+
47+
#region Functionality
48+
49+
public async Task Send(byte[] buffer, Action callback, Action<Exception> error)
50+
{
51+
try
52+
{
53+
await Stream.WriteAsync(buffer.AsMemory(), TokenSource.Token);
54+
await Stream.FlushAsync();
55+
56+
callback();
57+
}
58+
catch (Exception ex)
59+
{
60+
error(ex);
61+
}
62+
}
63+
64+
public async Task<int> Receive(byte[] buffer, Action<int> callback, Action<Exception> error, int offset = 0)
65+
{
66+
try
67+
{
68+
var result = await Stream.ReadAsync(buffer.AsMemory(offset), TokenSource.Token);
69+
70+
callback(result);
71+
72+
return result;
73+
}
74+
catch (Exception ex)
75+
{
76+
error(ex);
77+
return -1;
78+
}
79+
}
80+
81+
public void Close() => Dispose();
82+
83+
public void Dispose()
84+
{
85+
TokenSource.Cancel();
86+
Stream.Dispose();
87+
Socket.Dispose();
88+
}
89+
90+
#endregion
91+
92+
#region Unnecessary stuff
93+
94+
private const string NotRequiredByIntegration = "Not required by integration";
95+
96+
public EndPoint LocalEndPoint => throw new NotImplementedException(NotRequiredByIntegration);
97+
98+
public Task<ISocket> Accept(Action<ISocket> callback, Action<Exception> error)
99+
=> throw new NotImplementedException(NotRequiredByIntegration);
100+
101+
public void Bind(EndPoint ipLocal)
102+
=> throw new NotImplementedException(NotRequiredByIntegration);
103+
104+
public void Listen(int backlog)
105+
=> throw new NotImplementedException(NotRequiredByIntegration);
106+
107+
public Task Authenticate(X509Certificate2 certificate, SslProtocols enabledSslProtocols, Action callback, Action<Exception> error)
108+
=> throw new NotImplementedException(NotRequiredByIntegration);
109+
110+
#endregion
111+
112+
}

0 commit comments

Comments
 (0)