Skip to content

Commit 9cf88a0

Browse files
Changes to channels implementation and IChannelHub
1 parent e77c872 commit 9cf88a0

File tree

1 file changed

+37
-26
lines changed

1 file changed

+37
-26
lines changed

src/Saturn/Channels.fs

+37-26
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ open Microsoft.Extensions.DependencyInjection
88
open Microsoft.Extensions.Logging
99
open System
1010
open System.Collections.Concurrent
11-
open System.Collections.Generic
1211
open System.Net.WebSockets
1312
open System.Threading
1413
open System.Threading.Tasks
@@ -32,10 +31,10 @@ module Channels =
3231

3332
///Type representing information about client that has executed some channel action
3433
///It's passed as an argument in channel actions (`join`, `handle`, `terminate`)
35-
type ClientInfo = { SocketId: SocketId }
34+
type ClientInfo = { SocketId: SocketId; ChannelPath: ChannelPath }
3635
with
37-
static member New socketId =
38-
{ SocketId = socketId }
36+
static member New channelPath socketId =
37+
{ SocketId = socketId; ChannelPath = channelPath }
3938

4039
///Type representing result of `join` action. It can be either succesful (`Ok`) or you can reject client connection (`Rejected`)
4140
type JoinResult =
@@ -53,11 +52,12 @@ module Channels =
5352
/// You can get instance of it with `ctx.GetService<Saturn.Channels.ISocketHub>()` from any place that has access to HttpContext instance (`controller` actions, `channel` actions, normal `HttpHandler`)
5453
type ISocketHub =
5554
abstract member SendMessageToClients: ChannelPath -> Topic -> 'a -> Task<unit>
56-
abstract member SendMessageToClient: ChannelPath -> SocketId -> Topic -> 'a -> Task<unit>
55+
abstract member SendMessageToClient: SocketId -> Topic -> 'a -> Task<unit>
56+
abstract member SendMessageToClientsFilter: (ClientInfo -> bool) -> Topic -> 'a -> Task<unit>
5757

5858
/// A type that wraps access to connected websockets by endpoint
5959
type SocketHub(serializer: IJsonSerializer) =
60-
let sockets = Dictionary<ChannelPath, ConcurrentDictionary<SocketId, Socket.ThreadSafeWebSocket>>()
60+
let sockets = ConcurrentDictionary<ClientInfo, Socket.ThreadSafeWebSocket>()
6161

6262
let sendMessage (msg: 'a Message) (socket: Socket.ThreadSafeWebSocket) = task {
6363
let text = serializer.SerializeToString msg
@@ -67,37 +67,48 @@ module Channels =
6767
| Error exn -> return exn.Throw()
6868
}
6969

70-
member __.NewPath path =
71-
match sockets.TryGetValue path with
72-
| true, _path -> ()
73-
| false, _ -> sockets.[path] <- ConcurrentDictionary()
70+
member __.ConnectSocketToPath path clientId socket =
71+
let ci = {SocketId = clientId; ChannelPath = path}
72+
sockets.AddOrUpdate(ci, socket, fun _ _ -> socket) |> ignore
73+
ci
7474

75-
member __.ConnectSocketToPath path id socket =
76-
sockets.[path].AddOrUpdate(id, socket, fun _ _ -> socket) |> ignore
77-
id
78-
79-
member __.DisconnectSocketForPath path socketId =
80-
sockets.[path].TryRemove socketId |> ignore
75+
member __.DisconnectSocketForPath path clientId =
76+
let ci = {SocketId = clientId; ChannelPath = path}
77+
sockets.TryRemove ci |> ignore
8178

8279
interface ISocketHub with
80+
member __.SendMessageToClientsFilter(predicate: ClientInfo -> bool) (topic: Topic) (payload: 'a): Task<unit> = task {
81+
let msg = { Topic = topic; Ref = ""; Payload = payload }
82+
let tasks =
83+
sockets
84+
|> Seq.filter (fun n -> predicate n.Key)
85+
|> Seq.map (fun n -> sendMessage msg n.Value)
86+
87+
let! _results = Task.WhenAll tasks
88+
return ()
89+
}
90+
8391
member __.SendMessageToClients path topic payload = task {
8492
let msg = { Topic = topic; Ref = ""; Payload = payload }
85-
let tasks = [for kvp in sockets.[path] -> sendMessage msg kvp.Value ]
93+
let tasks =
94+
sockets
95+
|> Seq.filter (fun n -> n.Key.ChannelPath = path)
96+
|> Seq.map (fun n -> sendMessage msg n.Value)
97+
8698
let! _results = Task.WhenAll tasks
8799
return ()
88100
}
89101

90102
member __.SendMessageToClient path clientId topic payload = task {
91-
match sockets.[path].TryGetValue clientId with
103+
let ci = {SocketId = clientId; ChannelPath = path}
104+
match sockets.TryGetValue ci with
92105
| true, socket ->
93106
let msg = { Topic = topic; Ref = ""; Payload = payload }
94107
do! sendMessage msg socket
95108
| _ -> ()
96109
}
97110

98111
type SocketMiddleware(next : RequestDelegate, serializer: IJsonSerializer, path: string, channel: IChannel, sockets: SocketHub, logger: ILogger<SocketMiddleware>) =
99-
do sockets.NewPath path
100-
101112
member __.Invoke(ctx : HttpContext) =
102113
task {
103114
if ctx.Request.Path = PathString(path) then
@@ -106,14 +117,14 @@ module Channels =
106117
let logger = ctx.RequestServices.GetRequiredService<ILogger<SocketMiddleware>>()
107118
logger.LogTrace("Promoted websocket request")
108119
let socketId = Guid.NewGuid()
109-
let socketInfo = ClientInfo.New socketId
110-
let! joinResult = channel.Join(ctx, socketInfo)
120+
let clientInfo = ClientInfo.New path socketId
121+
let! joinResult = channel.Join(ctx, clientInfo)
111122
match joinResult with
112123
| Ok ->
113124
logger.LogTrace("Joined channel {path}", path)
114125
let! webSocket = ctx.WebSockets.AcceptWebSocketAsync()
115126
let wrappedSocket = Socket.createFromWebSocket webSocket
116-
let socketId = sockets.ConnectSocketToPath path socketId wrappedSocket
127+
let clientInfo = sockets.ConnectSocketToPath path socketId wrappedSocket
117128

118129
while wrappedSocket.State = WebSocketState.Open do
119130
match! Socket.receiveMessageAsUTF8 wrappedSocket with
@@ -122,7 +133,7 @@ module Channels =
122133
| Result.Ok (WebSocket.ReceiveUTF8Result.String msg) ->
123134
logger.LogTrace("received message {0}", msg)
124135
try
125-
do! channel.HandleMessage(ctx, socketInfo, serializer, msg)
136+
do! channel.HandleMessage(ctx, clientInfo, serializer, msg)
126137
with
127138
| ex ->
128139
// typically a deserialization error, swallow
@@ -132,8 +143,8 @@ module Channels =
132143
logger.LogError(exn.SourceException, "Error while receiving message")
133144
() // TODO: ?
134145

135-
do! channel.Terminate (ctx, socketInfo)
136-
sockets.DisconnectSocketForPath path socketId
146+
do! channel.Terminate (ctx, clientInfo)
147+
sockets.DisconnectSocketForPath path clientInfo.SocketId
137148
let! result = Socket.close wrappedSocket WebSocketCloseStatus.NormalClosure "Closing channel"
138149
match result with
139150
| Result.Ok () ->

0 commit comments

Comments
 (0)