Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/SocketForwarder/TCPForwarder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ public struct TCPForwarder: SocketForwarder {
return
bootstrap
.bind(to: self.proxyAddress)
.flatMap { $0.eventLoop.makeSucceededFuture(SocketForwarderResult(channel: $0)) }
.map { SocketForwarderResult(channel: $0) }
}
}
128 changes: 61 additions & 67 deletions Sources/SocketForwarder/UDPForwarder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import NIOFoundationCompat
import Synchronization

// Proxy backend for a single client address (clientIP, clientPort).
private final class UDPProxyBackend: ChannelInboundHandler, Sendable {
private final class UDPProxyBackend: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
typealias OutboundOut = AddressedEnvelope<ByteBuffer>

Expand All @@ -35,62 +35,56 @@ private final class UDPProxyBackend: ChannelInboundHandler, Sendable {
private let serverAddress: SocketAddress
private let frontendChannel: any Channel
private let log: Logger?
private let state: Mutex<State>
private var state: State

init(clientAddress: SocketAddress, serverAddress: SocketAddress, frontendChannel: any Channel, log: Logger? = nil) {
self.clientAddress = clientAddress
self.serverAddress = serverAddress
self.frontendChannel = frontendChannel
self.log = log
let initialState = State(queuedPayloads: Deque(), channel: nil)
self.state = Mutex(initialState)
self.state = initialState
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
// relay data from server to client.
let inbound = self.unwrapInboundIn(data)
let outbound = OutboundOut(remoteAddress: self.clientAddress, data: inbound.data)
self.log?.trace("backend - writing datagram to client")
_ = self.frontendChannel.writeAndFlush(outbound)
self.frontendChannel.writeAndFlush(outbound, promise: nil)
}

func channelActive(context: ChannelHandlerContext) {
state.withLock {
if !$0.queuedPayloads.isEmpty {
self.log?.trace("backend - writing \($0.queuedPayloads.count) queued datagrams to server")
while let queuedData = $0.queuedPayloads.popFirst() {
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: queuedData)
_ = context.channel.writeAndFlush(outbound)
}
if !state.queuedPayloads.isEmpty {
self.log?.trace("backend - writing \(state.queuedPayloads.count) queued datagrams to server")
while let queuedData = state.queuedPayloads.popFirst() {
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: queuedData)
context.channel.writeAndFlush(outbound, promise: nil)
}
$0.channel = context.channel
}
state.channel = context.channel
}

func write(data: ByteBuffer) {
// change package remote address from proxy server to real server
state.withLock {
if let channel = $0.channel {
// channel has been initialized, so relay any queued packets, along with this one to outbound
self.log?.trace("backend - writing datagram to server")
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: data)
_ = channel.writeAndFlush(outbound)
} else {
// channel is initializing, queue
self.log?.trace("backend - queuing datagram")
$0.queuedPayloads.append(data)
}
if let channel = state.channel {
// channel has been initialized, so relay any queued packets, along with this one to outbound
self.log?.trace("backend - writing datagram to server")
let outbound: UDPProxyBackend.OutboundOut = OutboundOut(remoteAddress: self.serverAddress, data: data)
channel.writeAndFlush(outbound, promise: nil)
} else {
// channel is initializing, queue
self.log?.trace("backend - queuing datagram")
state.queuedPayloads.append(data)
}
}

func close() {
state.withLock {
guard let channel = $0.channel else {
self.log?.warning("backend - close on inactive channel")
return
}
_ = channel.close()
guard let channel = state.channel else {
self.log?.warning("backend - close on inactive channel")
return
}
_ = channel.close()
}
}

Expand All @@ -99,23 +93,21 @@ private struct ProxyContext {
public let closeFuture: EventLoopFuture<Void>
}

private final class UDPProxyFrontend: ChannelInboundHandler, Sendable {
private final class UDPProxyFrontend: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
typealias OutboundOut = AddressedEnvelope<ByteBuffer>
private let maxProxies = UInt(256)

private let proxyAddress: SocketAddress
private let serverAddress: SocketAddress
private let eventLoopGroup: any EventLoopGroup
private let log: Logger?

private let proxies: Mutex<LRUCache<String, ProxyContext>>
private var proxies: LRUCache<String, ProxyContext>

init(proxyAddress: SocketAddress, serverAddress: SocketAddress, eventLoopGroup: any EventLoopGroup, log: Logger? = nil) {
init(proxyAddress: SocketAddress, serverAddress: SocketAddress, log: Logger? = nil) {
self.proxyAddress = proxyAddress
self.serverAddress = serverAddress
self.eventLoopGroup = eventLoopGroup
self.proxies = Mutex(LRUCache(size: maxProxies))
self.proxies = LRUCache(size: maxProxies)
self.log = log
}

Expand All @@ -134,33 +126,34 @@ private final class UDPProxyFrontend: ChannelInboundHandler, Sendable {

let key = "\(clientIP):\(clientPort)"
do {
try proxies.withLock {
if let context = $0.get(key) {
context.proxy.write(data: inbound.data)
} else {
self.log?.trace("frontend - creating backend")
let proxy = UDPProxyBackend(
clientAddress: inbound.remoteAddress,
serverAddress: self.serverAddress,
frontendChannel: context.channel,
log: log
)
let proxyAddress = try SocketAddress(ipAddress: "0.0.0.0", port: 0)
let proxyToServerFuture = DatagramBootstrap(group: self.eventLoopGroup)
.channelInitializer {
self.log?.trace("frontend - initializing backend")
return $0.pipeline.addHandler(proxy)
if let context = proxies.get(key) {
context.proxy.write(data: inbound.data)
} else {
self.log?.trace("frontend - creating backend")
let proxy = UDPProxyBackend(
clientAddress: inbound.remoteAddress,
serverAddress: self.serverAddress,
frontendChannel: context.channel,
log: log
)
let proxyAddress = try SocketAddress(ipAddress: "0.0.0.0", port: 0)
let loopBoundProxy = NIOLoopBound(proxy, eventLoop: context.eventLoop)
let proxyToServerFuture = DatagramBootstrap(group: context.eventLoop)
.channelInitializer { [log] channel in
log?.trace("frontend - initializing backend")
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(loopBoundProxy.value)
}
.bind(to: proxyAddress)
.flatMap { $0.closeFuture }
let context = ProxyContext(proxy: proxy, closeFuture: proxyToServerFuture)
if let (_, evictedContext) = $0.put(key: key, value: context) {
self.log?.trace("frontend - closing evicted backend")
evictedContext.proxy.close()
}

proxy.write(data: inbound.data)
.bind(to: proxyAddress)
.flatMap { $0.closeFuture }
let context = ProxyContext(proxy: proxy, closeFuture: proxyToServerFuture)
if let (_, evictedContext) = proxies.put(key: key, value: context) {
self.log?.trace("frontend - closing evicted backend")
evictedContext.proxy.close()
}

proxy.write(data: inbound.data)
}
} catch {
log?.error("server handler - backend channel creation failed with error: \(error)")
Expand Down Expand Up @@ -192,20 +185,21 @@ public struct UDPForwarder: SocketForwarder {

public func run() throws -> EventLoopFuture<SocketForwarderResult> {
self.log?.trace("frontend - creating channel")
let proxyToServerHandler = UDPProxyFrontend(
proxyAddress: proxyAddress,
serverAddress: serverAddress,
eventLoopGroup: self.eventLoopGroup,
log: log
)
let bootstrap = DatagramBootstrap(group: self.eventLoopGroup)
.channelInitializer { serverChannel in
self.log?.trace("frontend - initializing channel")
return serverChannel.pipeline.addHandler(proxyToServerHandler)
let proxyToServerHandler = UDPProxyFrontend(
proxyAddress: proxyAddress,
serverAddress: serverAddress,
log: log
)
return serverChannel.eventLoop.makeCompletedFuture {
try serverChannel.pipeline.syncOperations.addHandler(proxyToServerHandler)
}
}
return
bootstrap
.bind(to: proxyAddress)
.flatMap { $0.eventLoop.makeSucceededFuture(SocketForwarderResult(channel: $0)) }
.map { SocketForwarderResult(channel: $0) }
}
}