Skip to content

Commit 66bb40a

Browse files
committed
Proxy refinements (thanks to @Lukasa).
- Lock-free UDP proxy forwarding by puttng front and back end on the same thread. - Cleans up some Swift language awkwardness in the proxies.
1 parent 827b46c commit 66bb40a

File tree

2 files changed

+62
-68
lines changed

2 files changed

+62
-68
lines changed

Sources/SocketForwarder/TCPForwarder.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,6 @@ public struct TCPForwarder: SocketForwarder {
5757
return
5858
bootstrap
5959
.bind(to: self.proxyAddress)
60-
.flatMap { $0.eventLoop.makeSucceededFuture(SocketForwarderResult(channel: $0)) }
60+
.map { SocketForwarderResult(channel: $0) }
6161
}
6262
}

Sources/SocketForwarder/UDPForwarder.swift

Lines changed: 61 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import NIOFoundationCompat
2222
import Synchronization
2323

2424
// Proxy backend for a single client address (clientIP, clientPort).
25-
private final class UDPProxyBackend: ChannelInboundHandler, Sendable {
25+
private final class UDPProxyBackend: ChannelInboundHandler {
2626
typealias InboundIn = AddressedEnvelope<ByteBuffer>
2727
typealias OutboundOut = AddressedEnvelope<ByteBuffer>
2828

@@ -35,62 +35,56 @@ private final class UDPProxyBackend: ChannelInboundHandler, Sendable {
3535
private let serverAddress: SocketAddress
3636
private let frontendChannel: any Channel
3737
private let log: Logger?
38-
private let state: Mutex<State>
38+
private var state: State
3939

4040
init(clientAddress: SocketAddress, serverAddress: SocketAddress, frontendChannel: any Channel, log: Logger? = nil) {
4141
self.clientAddress = clientAddress
4242
self.serverAddress = serverAddress
4343
self.frontendChannel = frontendChannel
4444
self.log = log
4545
let initialState = State(queuedPayloads: Deque(), channel: nil)
46-
self.state = Mutex(initialState)
46+
self.state = initialState
4747
}
4848

4949
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
5050
// relay data from server to client.
5151
let inbound = self.unwrapInboundIn(data)
5252
let outbound = OutboundOut(remoteAddress: self.clientAddress, data: inbound.data)
5353
self.log?.trace("backend - writing datagram to client")
54-
_ = self.frontendChannel.writeAndFlush(outbound)
54+
self.frontendChannel.writeAndFlush(outbound, promise: nil)
5555
}
5656

5757
func channelActive(context: ChannelHandlerContext) {
58-
state.withLock {
59-
if !$0.queuedPayloads.isEmpty {
60-
self.log?.trace("backend - writing \($0.queuedPayloads.count) queued datagrams to server")
61-
while let queuedData = $0.queuedPayloads.popFirst() {
62-
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: queuedData)
63-
_ = context.channel.writeAndFlush(outbound)
64-
}
58+
if !state.queuedPayloads.isEmpty {
59+
self.log?.trace("backend - writing \(state.queuedPayloads.count) queued datagrams to server")
60+
while let queuedData = state.queuedPayloads.popFirst() {
61+
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: queuedData)
62+
context.channel.writeAndFlush(outbound, promise: nil)
6563
}
66-
$0.channel = context.channel
6764
}
65+
state.channel = context.channel
6866
}
6967

7068
func write(data: ByteBuffer) {
7169
// change package remote address from proxy server to real server
72-
state.withLock {
73-
if let channel = $0.channel {
74-
// channel has been initialized, so relay any queued packets, along with this one to outbound
75-
self.log?.trace("backend - writing datagram to server")
76-
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: data)
77-
_ = channel.writeAndFlush(outbound)
78-
} else {
79-
// channel is initializing, queue
80-
self.log?.trace("backend - queuing datagram")
81-
$0.queuedPayloads.append(data)
82-
}
70+
if let channel = state.channel {
71+
// channel has been initialized, so relay any queued packets, along with this one to outbound
72+
self.log?.trace("backend - writing datagram to server")
73+
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: data)
74+
channel.writeAndFlush(outbound, promise: nil)
75+
} else {
76+
// channel is initializing, queue
77+
self.log?.trace("backend - queuing datagram")
78+
state.queuedPayloads.append(data)
8379
}
8480
}
8581

8682
func close() {
87-
state.withLock {
88-
guard let channel = $0.channel else {
89-
self.log?.warning("backend - close on inactive channel")
90-
return
91-
}
92-
_ = channel.close()
83+
guard let channel = state.channel else {
84+
self.log?.warning("backend - close on inactive channel")
85+
return
9386
}
87+
_ = channel.close()
9488
}
9589
}
9690

@@ -99,23 +93,21 @@ private struct ProxyContext {
9993
public let closeFuture: EventLoopFuture<Void>
10094
}
10195

102-
private final class UDPProxyFrontend: ChannelInboundHandler, Sendable {
96+
private final class UDPProxyFrontend: ChannelInboundHandler {
10397
typealias InboundIn = AddressedEnvelope<ByteBuffer>
10498
typealias OutboundOut = AddressedEnvelope<ByteBuffer>
10599
private let maxProxies = UInt(256)
106100

107101
private let proxyAddress: SocketAddress
108102
private let serverAddress: SocketAddress
109-
private let eventLoopGroup: any EventLoopGroup
110103
private let log: Logger?
111104

112-
private let proxies: Mutex<LRUCache<String, ProxyContext>>
105+
private var proxies: LRUCache<String, ProxyContext>
113106

114-
init(proxyAddress: SocketAddress, serverAddress: SocketAddress, eventLoopGroup: any EventLoopGroup, log: Logger? = nil) {
107+
init(proxyAddress: SocketAddress, serverAddress: SocketAddress, log: Logger? = nil) {
115108
self.proxyAddress = proxyAddress
116109
self.serverAddress = serverAddress
117-
self.eventLoopGroup = eventLoopGroup
118-
self.proxies = Mutex(LRUCache(size: maxProxies))
110+
self.proxies = LRUCache(size: maxProxies)
119111
self.log = log
120112
}
121113

@@ -134,33 +126,34 @@ private final class UDPProxyFrontend: ChannelInboundHandler, Sendable {
134126

135127
let key = "\(clientIP):\(clientPort)"
136128
do {
137-
try proxies.withLock {
138-
if let context = $0.get(key) {
139-
context.proxy.write(data: inbound.data)
140-
} else {
141-
self.log?.trace("frontend - creating backend")
142-
let proxy = UDPProxyBackend(
143-
clientAddress: inbound.remoteAddress,
144-
serverAddress: self.serverAddress,
145-
frontendChannel: context.channel,
146-
log: log
147-
)
148-
let proxyAddress = try SocketAddress(ipAddress: "0.0.0.0", port: 0)
149-
let proxyToServerFuture = DatagramBootstrap(group: self.eventLoopGroup)
150-
.channelInitializer {
151-
self.log?.trace("frontend - initializing backend")
152-
return $0.pipeline.addHandler(proxy)
129+
if let context = proxies.get(key) {
130+
context.proxy.write(data: inbound.data)
131+
} else {
132+
self.log?.trace("frontend - creating backend")
133+
let proxy = UDPProxyBackend(
134+
clientAddress: inbound.remoteAddress,
135+
serverAddress: self.serverAddress,
136+
frontendChannel: context.channel,
137+
log: log
138+
)
139+
let proxyAddress = try SocketAddress(ipAddress: "0.0.0.0", port: 0)
140+
let loopBoundProxy = NIOLoopBound(proxy, eventLoop: context.eventLoop)
141+
let proxyToServerFuture = DatagramBootstrap(group: context.eventLoop)
142+
.channelInitializer { [log] channel in
143+
log?.trace("frontend - initializing backend")
144+
return channel.eventLoop.makeCompletedFuture {
145+
try channel.pipeline.syncOperations.addHandler(loopBoundProxy.value)
153146
}
154-
.bind(to: proxyAddress)
155-
.flatMap { $0.closeFuture }
156-
let context = ProxyContext(proxy: proxy, closeFuture: proxyToServerFuture)
157-
if let (_, evictedContext) = $0.put(key: key, value: context) {
158-
self.log?.trace("frontend - closing evicted backend")
159-
evictedContext.proxy.close()
160147
}
161-
162-
proxy.write(data: inbound.data)
148+
.bind(to: proxyAddress)
149+
.flatMap { $0.closeFuture }
150+
let context = ProxyContext(proxy: proxy, closeFuture: proxyToServerFuture)
151+
if let (_, evictedContext) = proxies.put(key: key, value: context) {
152+
self.log?.trace("frontend - closing evicted backend")
153+
evictedContext.proxy.close()
163154
}
155+
156+
proxy.write(data: inbound.data)
164157
}
165158
} catch {
166159
log?.error("server handler - backend channel creation failed with error: \(error)")
@@ -192,20 +185,21 @@ public struct UDPForwarder: SocketForwarder {
192185

193186
public func run() throws -> EventLoopFuture<SocketForwarderResult> {
194187
self.log?.trace("frontend - creating channel")
195-
let proxyToServerHandler = UDPProxyFrontend(
196-
proxyAddress: proxyAddress,
197-
serverAddress: serverAddress,
198-
eventLoopGroup: self.eventLoopGroup,
199-
log: log
200-
)
201188
let bootstrap = DatagramBootstrap(group: self.eventLoopGroup)
202189
.channelInitializer { serverChannel in
203190
self.log?.trace("frontend - initializing channel")
204-
return serverChannel.pipeline.addHandler(proxyToServerHandler)
191+
let proxyToServerHandler = UDPProxyFrontend(
192+
proxyAddress: proxyAddress,
193+
serverAddress: serverAddress,
194+
log: log
195+
)
196+
return serverChannel.eventLoop.makeCompletedFuture {
197+
try serverChannel.pipeline.syncOperations.addHandler(proxyToServerHandler)
198+
}
205199
}
206200
return
207201
bootstrap
208202
.bind(to: proxyAddress)
209-
.flatMap { $0.eventLoop.makeSucceededFuture(SocketForwarderResult(channel: $0)) }
203+
.map { SocketForwarderResult(channel: $0) }
210204
}
211205
}

0 commit comments

Comments
 (0)