@@ -8,7 +8,6 @@ open Microsoft.Extensions.DependencyInjection
8
8
open Microsoft.Extensions .Logging
9
9
open System
10
10
open System.Collections .Concurrent
11
- open System.Collections .Generic
12
11
open System.Net .WebSockets
13
12
open System.Threading
14
13
open System.Threading .Tasks
@@ -32,10 +31,10 @@ module Channels =
32
31
33
32
///Type representing information about client that has executed some channel action
34
33
///It's passed as an argument in channel actions (`join`, `handle`, `terminate`)
35
- type ClientInfo = { SocketId: SocketId }
34
+ type ClientInfo = { SocketId: SocketId ; ChannelPath : ChannelPath }
36
35
with
37
- static member New socketId =
38
- { SocketId = socketId }
36
+ static member New channelPath socketId =
37
+ { SocketId = socketId; ChannelPath = channelPath }
39
38
40
39
///Type representing result of `join` action. It can be either succesful (`Ok`) or you can reject client connection (`Rejected`)
41
40
type JoinResult =
@@ -53,11 +52,12 @@ module Channels =
53
52
/// You can get instance of it with `ctx.GetService<Saturn.Channels.ISocketHub>()` from any place that has access to HttpContext instance (`controller` actions, `channel` actions, normal `HttpHandler`)
54
53
type ISocketHub =
55
54
abstract member SendMessageToClients: ChannelPath -> Topic -> 'a -> Task < unit >
56
- abstract member SendMessageToClient: ChannelPath -> SocketId -> Topic -> 'a -> Task < unit >
55
+ abstract member SendMessageToClient: SocketId -> Topic -> 'a -> Task < unit >
56
+ abstract member SendMessageToClientsFilter: ( ClientInfo -> bool ) -> Topic -> 'a -> Task < unit >
57
57
58
58
/// A type that wraps access to connected websockets by endpoint
59
59
type SocketHub ( serializer : IJsonSerializer ) =
60
- let sockets = Dictionary < ChannelPath , ConcurrentDictionary< SocketId , Socket.ThreadSafeWebSocket> >()
60
+ let sockets = ConcurrentDictionary< ClientInfo , Socket.ThreadSafeWebSocket>()
61
61
62
62
let sendMessage ( msg : 'a Message ) ( socket : Socket.ThreadSafeWebSocket ) = task {
63
63
let text = serializer.SerializeToString msg
@@ -67,37 +67,48 @@ module Channels =
67
67
| Error exn -> return exn.Throw()
68
68
}
69
69
70
- member __.NewPath path =
71
- match sockets.TryGetValue path with
72
- | true , _ path -> ()
73
- | false , _ -> sockets .[ path ] <- ConcurrentDictionary ()
70
+ member __.ConnectSocketToPath path clientId socket =
71
+ let ci = { SocketId = clientId ; ChannelPath = path }
72
+ sockets.AddOrUpdate ( ci , socket , fun _ _ -> socket ) |> ignore
73
+ ci
74
74
75
- member __.ConnectSocketToPath path id socket =
76
- sockets.[ path]. AddOrUpdate( id, socket, fun _ _ -> socket) |> ignore
77
- id
78
-
79
- member __.DisconnectSocketForPath path socketId =
80
- sockets.[ path]. TryRemove socketId |> ignore
75
+ member __.DisconnectSocketForPath path clientId =
76
+ let ci = { SocketId = clientId; ChannelPath = path}
77
+ sockets.TryRemove ci |> ignore
81
78
82
79
interface ISocketHub with
80
+ member __.SendMessageToClientsFilter ( predicate : ClientInfo -> bool ) ( topic : Topic ) ( payload : 'a ): Task < unit > = task {
81
+ let msg = { Topic = topic; Ref = " " ; Payload = payload }
82
+ let tasks =
83
+ sockets
84
+ |> Seq.filter ( fun n -> predicate n.Key)
85
+ |> Seq.map ( fun n -> sendMessage msg n.Value)
86
+
87
+ let! _results = Task.WhenAll tasks
88
+ return ()
89
+ }
90
+
83
91
member __.SendMessageToClients path topic payload = task {
84
92
let msg = { Topic = topic; Ref = " " ; Payload = payload }
85
- let tasks = [ for kvp in sockets.[ path] -> sendMessage msg kvp.Value ]
93
+ let tasks =
94
+ sockets
95
+ |> Seq.filter ( fun n -> n.Key.ChannelPath = path)
96
+ |> Seq.map ( fun n -> sendMessage msg n.Value)
97
+
86
98
let! _results = Task.WhenAll tasks
87
99
return ()
88
100
}
89
101
90
102
member __.SendMessageToClient path clientId topic payload = task {
91
- match sockets.[ path]. TryGetValue clientId with
103
+ let ci = { SocketId = clientId; ChannelPath = path}
104
+ match sockets.TryGetValue ci with
92
105
| true , socket ->
93
106
let msg = { Topic = topic; Ref = " " ; Payload = payload }
94
107
do ! sendMessage msg socket
95
108
| _ -> ()
96
109
}
97
110
98
111
type SocketMiddleware ( next : RequestDelegate , serializer : IJsonSerializer , path : string , channel : IChannel , sockets : SocketHub , logger : ILogger < SocketMiddleware >) =
99
- do sockets.NewPath path
100
-
101
112
member __.Invoke ( ctx : HttpContext ) =
102
113
task {
103
114
if ctx.Request.Path = PathString( path) then
@@ -106,14 +117,14 @@ module Channels =
106
117
let logger = ctx.RequestServices.GetRequiredService< ILogger< SocketMiddleware>>()
107
118
logger.LogTrace( " Promoted websocket request" )
108
119
let socketId = Guid.NewGuid()
109
- let socketInfo = ClientInfo.New socketId
110
- let! joinResult = channel.Join( ctx, socketInfo )
120
+ let clientInfo = ClientInfo.New path socketId
121
+ let! joinResult = channel.Join( ctx, clientInfo )
111
122
match joinResult with
112
123
| Ok ->
113
124
logger.LogTrace( " Joined channel {path}" , path)
114
125
let! webSocket = ctx.WebSockets.AcceptWebSocketAsync()
115
126
let wrappedSocket = Socket.createFromWebSocket webSocket
116
- let socketId = sockets.ConnectSocketToPath path socketId wrappedSocket
127
+ let clientInfo = sockets.ConnectSocketToPath path socketId wrappedSocket
117
128
118
129
while wrappedSocket.State = WebSocketState.Open do
119
130
match ! Socket.receiveMessageAsUTF8 wrappedSocket with
@@ -122,7 +133,7 @@ module Channels =
122
133
| Result.Ok ( WebSocket.ReceiveUTF8Result.String msg) ->
123
134
logger.LogTrace( " received message {0}" , msg)
124
135
try
125
- do ! channel.HandleMessage( ctx, socketInfo , serializer, msg)
136
+ do ! channel.HandleMessage( ctx, clientInfo , serializer, msg)
126
137
with
127
138
| ex ->
128
139
// typically a deserialization error, swallow
@@ -132,8 +143,8 @@ module Channels =
132
143
logger.LogError( exn.SourceException, " Error while receiving message" )
133
144
() // TODO: ?
134
145
135
- do ! channel.Terminate ( ctx, socketInfo )
136
- sockets.DisconnectSocketForPath path socketId
146
+ do ! channel.Terminate ( ctx, clientInfo )
147
+ sockets.DisconnectSocketForPath path clientInfo.SocketId
137
148
let! result = Socket.close wrappedSocket WebSocketCloseStatus.NormalClosure " Closing channel"
138
149
match result with
139
150
| Result.Ok () ->
0 commit comments