diff --git a/access_server/src/lib.rs b/access_server/src/lib.rs index 92d89a3..aa27705 100644 --- a/access_server/src/lib.rs +++ b/access_server/src/lib.rs @@ -3,31 +3,85 @@ use std::{collections::HashMap, sync::Arc}; use common::{ config::{merge_map, Merge}, error::{AnyError, AnyResult}, - filter::{self, FilterBuilder, MatcherBuilder}, + filter::{self, MatcherBuilder}, loading, - stream::proxy_table::StreamProxyConfig, - udp::proxy_table::UdpProxyConfig, + udp::proxy_table::{UdpProxyConfig, UdpProxyGroupBuilder, UdpProxyTableBuilder}, }; -use protocol::{context::ConcreteContext, stream::addr::ConcreteStreamType}; +use protocol::{ + context::ConcreteContext, + stream::proxy_table::{StreamProxyConfig, StreamProxyGroupBuilder, StreamProxyTableBuilder}, +}; +use proxy_client::{stream::StreamTracerBuilder, udp::UdpTracerBuilder}; use serde::{Deserialize, Serialize}; use socks5::server::{ tcp::{Socks5ServerTcpAccess, Socks5ServerTcpAccessServerConfig}, udp::{Socks5ServerUdpAccess, Socks5ServerUdpAccessServerConfig}, }; use stream::{ - proxy_table::StreamProxyTableBuilder, + proxy_table::{StreamProxyGroupBuildContext, StreamProxyTableBuildContext}, streams::{ http_tunnel::{HttpAccess, HttpAccessServerConfig}, tcp::{TcpAccess, TcpAccessServerConfig}, }, }; use tokio_util::sync::CancellationToken; -use udp::{proxy_table::UdpProxyTableBuilder, UdpAccess, UdpAccessServerConfig}; +use udp::{ + proxy_table::{UdpProxyGroupBuildContext, UdpProxyTableBuildContext}, + UdpAccess, UdpAccessServerConfig, +}; pub mod socks5; pub mod stream; pub mod udp; +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(deny_unknown_fields)] +pub struct AccessServerStream { + #[serde(default)] + pub proxy_table: HashMap, StreamProxyTableBuilder>, + #[serde(default)] + pub proxy_group: HashMap, StreamProxyGroupBuilder>, +} +impl Merge for AccessServerStream { + type Error = AnyError; + + fn merge(self, other: Self) -> Result + where + Self: Sized, + { + let proxy_table = merge_map(self.proxy_table, other.proxy_table)?; + let proxy_group = merge_map(self.proxy_group, other.proxy_group)?; + Ok(Self { + proxy_table, + proxy_group, + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(deny_unknown_fields)] +pub struct AccessServerUdp { + #[serde(default)] + pub proxy_table: HashMap, UdpProxyTableBuilder>, + #[serde(default)] + pub proxy_group: HashMap, UdpProxyGroupBuilder>, +} +impl Merge for AccessServerUdp { + type Error = AnyError; + + fn merge(self, other: Self) -> Result + where + Self: Sized, + { + let proxy_table = merge_map(self.proxy_table, other.proxy_table)?; + let proxy_group = merge_map(self.proxy_group, other.proxy_group)?; + Ok(Self { + proxy_table, + proxy_group, + }) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] #[serde(deny_unknown_fields)] pub struct AccessServerConfig { @@ -42,13 +96,11 @@ pub struct AccessServerConfig { #[serde(default)] pub socks5_udp_server: Vec, #[serde(default)] - pub stream_proxy_tables: HashMap, StreamProxyTableBuilder>, - #[serde(default)] - pub udp_proxy_tables: HashMap, UdpProxyTableBuilder>, + stream: AccessServerStream, #[serde(default)] - pub matchers: HashMap, MatcherBuilder>, + udp: AccessServerUdp, #[serde(default)] - pub filters: HashMap, FilterBuilder>, + pub matcher: HashMap, MatcherBuilder>, } impl AccessServerConfig { pub fn new() -> AccessServerConfig { @@ -58,10 +110,9 @@ impl AccessServerConfig { http_server: Default::default(), socks5_tcp_server: Default::default(), socks5_udp_server: Default::default(), - 
stream_proxy_tables: Default::default(), - udp_proxy_tables: Default::default(), - matchers: Default::default(), - filters: Default::default(), + stream: Default::default(), + udp: Default::default(), + matcher: Default::default(), } } @@ -71,29 +122,80 @@ impl AccessServerConfig { loader: &mut AccessServerLoader, cancellation: CancellationToken, context: ConcreteContext, - stream_proxy: &HashMap, StreamProxyConfig>, - udp_proxy: &HashMap, UdpProxyConfig>, + stream_proxy_server: &HashMap, StreamProxyConfig>, + udp_proxy_server: &HashMap, UdpProxyConfig>, ) -> AnyResult { // Shared + let matcher: HashMap, filter::Matcher> = self + .matcher + .into_iter() + .map(|(k, v)| match v.build() { + Ok(v) => Ok((k, v)), + Err(e) => Err(e), + }) + .collect::, _>>()?; + + // Stream + let stream_trace_builder = StreamTracerBuilder::new(context.stream.clone()); + let stream_proxy_group_cx = StreamProxyGroupBuildContext { + proxy_server: stream_proxy_server, + tracer_builder: &stream_trace_builder, + cancellation: cancellation.clone(), + }; + let stream_proxy_group = self + .stream + .proxy_group + .into_iter() + .map(|(k, v)| match v.build(stream_proxy_group_cx.clone()) { + Ok(v) => Ok((k, v)), + Err(e) => Err(e), + }) + .collect::, _>>()?; + let stream_proxy_table_cx = StreamProxyTableBuildContext { + matcher: &matcher, + proxy_group: &stream_proxy_group, + proxy_group_cx: stream_proxy_group_cx.clone(), + }; let stream_proxy_tables = self - .stream_proxy_tables + .stream + .proxy_table .into_iter() - .map( - |(k, v)| match v.build(stream_proxy, &context.stream, cancellation.clone()) { - Ok(v) => Ok((k, v)), - Err(e) => Err(e), - }, - ) + .map(|(k, v)| match v.build(stream_proxy_table_cx.clone()) { + Ok(v) => Ok((k, v)), + Err(e) => Err(e), + }) .collect::, _>>()?; - let udp_proxy_tables = self - .udp_proxy_tables + + // UDP + let udp_trace_builder = UdpTracerBuilder::new(); + let udp_proxy_group_cx = UdpProxyGroupBuildContext { + proxy_server: udp_proxy_server, + tracer_builder: &udp_trace_builder, + cancellation: cancellation.clone(), + }; + let udp_proxy_group = self + .udp + .proxy_group .into_iter() - .map(|(k, v)| match v.build(udp_proxy, cancellation.clone()) { + .map(|(k, v)| match v.build(udp_proxy_group_cx.clone()) { + Ok(v) => Ok((k, v)), + Err(e) => Err(e), + }) + .collect::, _>>()?; + let udp_proxy_table_cx = UdpProxyTableBuildContext { + matcher: &matcher, + proxy_group: &udp_proxy_group, + proxy_group_cx: udp_proxy_group_cx.clone(), + }; + let _udp_proxy_tables = self + .udp + .proxy_table + .into_iter() + .map(|(k, v)| match v.build(udp_proxy_table_cx.clone()) { Ok(v) => Ok((k, v)), Err(e) => Err(e), }) .collect::, _>>()?; - let filters = filter::build_from_map(self.matchers, self.filters)?; // TCP servers let tcp_server = self @@ -101,9 +203,8 @@ impl AccessServerConfig { .into_iter() .map(|c| { c.into_builder( - stream_proxy, - &stream_proxy_tables, - cancellation.clone(), + &stream_proxy_group, + stream_proxy_group_cx.clone(), context.stream.clone(), ) }) @@ -119,9 +220,8 @@ impl AccessServerConfig { .into_iter() .map(|c| { c.into_builder( - udp_proxy, - &udp_proxy_tables, - cancellation.clone(), + &udp_proxy_group, + udp_proxy_group_cx.clone(), context.udp.clone(), ) }) @@ -137,10 +237,8 @@ impl AccessServerConfig { .into_iter() .map(|c| { c.into_builder( - stream_proxy, &stream_proxy_tables, - &filters, - cancellation.clone(), + stream_proxy_table_cx.clone(), context.stream.clone(), ) }) @@ -156,10 +254,8 @@ impl AccessServerConfig { .into_iter() .map(|c| { c.into_builder( - 
stream_proxy, &stream_proxy_tables, - &filters, - cancellation.clone(), + stream_proxy_table_cx.clone(), context.stream.clone(), ) }) @@ -175,9 +271,8 @@ impl AccessServerConfig { .into_iter() .map(|c| { c.into_builder( - udp_proxy, - &udp_proxy_tables, - cancellation.clone(), + &udp_proxy_group, + udp_proxy_group_cx.clone(), context.udp.clone(), ) }) @@ -202,20 +297,18 @@ impl Merge for AccessServerConfig { self.http_server.extend(other.http_server); self.socks5_tcp_server.extend(other.socks5_tcp_server); self.socks5_udp_server.extend(other.socks5_udp_server); - let stream_proxy_tables = merge_map(self.stream_proxy_tables, other.stream_proxy_tables)?; - let udp_proxy_tables = merge_map(self.udp_proxy_tables, other.udp_proxy_tables)?; - let matchers = merge_map(self.matchers, other.matchers)?; - let filters = merge_map(self.filters, other.filters)?; + let stream = self.stream.merge(other.stream)?; + let udp = self.udp.merge(other.udp)?; + let matcher = merge_map(self.matcher, other.matcher)?; Ok(Self { tcp_server: self.tcp_server, udp_server: self.udp_server, http_server: self.http_server, socks5_tcp_server: self.socks5_tcp_server, socks5_udp_server: self.socks5_udp_server, - stream_proxy_tables, - udp_proxy_tables, - matchers, - filters, + stream, + udp, + matcher, }) } } diff --git a/access_server/src/socks5/server/tcp.rs b/access_server/src/socks5/server/tcp.rs index e849c76..adcb163 100644 --- a/access_server/src/socks5/server/tcp.rs +++ b/access_server/src/socks5/server/tcp.rs @@ -4,24 +4,25 @@ use async_speed_limit::Limiter; use common::{ addr::{InternetAddr, InternetAddrStr}, config::SharableConfig, - filter::{self, Filter, FilterBuilder}, loading::{self, Hook}, + proxy_table::{ProxyAction, ProxyTableBuildError}, stream::{ addr::StreamAddr, io_copy::{CopyBidirectional, MetricContext}, - proxy_table::{StreamProxyConfig, StreamProxyTable}, IoAddr, IoStream, StreamServerHook, }, }; use protocol::stream::{ - addr::ConcreteStreamType, connection::ConnAndAddr, context::ConcreteStreamContext, + addr::ConcreteStreamType, + connection::ConnAndAddr, + context::ConcreteStreamContext, + proxy_table::{StreamProxyGroup, StreamProxyTable, StreamProxyTableBuilder}, streams::tcp::TcpServer, }; use proxy_client::stream::StreamEstablishError; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::io::AsyncReadExt; -use tokio_util::sync::CancellationToken; use tracing::{error, trace, warn}; use crate::{ @@ -32,7 +33,7 @@ use crate::{ Command, MethodIdentifier, NegotiationRequest, NegotiationResponse, RelayRequest, RelayResponse, Reply, }, - stream::proxy_table::{StreamProxyTableBuildError, StreamProxyTableBuilder}, + stream::proxy_table::StreamProxyTableBuildContext, }; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -40,7 +41,6 @@ use crate::{ pub struct Socks5ServerTcpAccessServerConfig { pub listen_addr: Arc, pub proxy_table: SharableConfig, - pub filter: SharableConfig, pub speed_limit: Option, pub udp_server_addr: Option, #[serde(default)] @@ -50,25 +50,16 @@ pub struct Socks5ServerTcpAccessServerConfig { impl Socks5ServerTcpAccessServerConfig { pub fn into_builder( self, - stream_proxy: &HashMap, StreamProxyConfig>, - proxy_tables: &HashMap, StreamProxyTable>, - filters: &HashMap, Filter>, - cancellation: CancellationToken, + proxy_table: &HashMap, StreamProxyTable>, + proxy_table_cx: StreamProxyTableBuildContext<'_>, stream_context: ConcreteStreamContext, ) -> Result { let proxy_table = match self.proxy_table { - SharableConfig::SharingKey(key) => proxy_tables + 
SharableConfig::SharingKey(key) => proxy_table .get(&key) .ok_or_else(|| BuildError::ProxyTableKeyNotFound(key.clone()))? .clone(), - SharableConfig::Private(x) => x.build(stream_proxy, &stream_context, cancellation)?, - }; - let filter = match self.filter { - SharableConfig::SharingKey(key) => filters - .get(&key) - .ok_or_else(|| BuildError::FilterKeyNotFound(key.clone()))? - .clone(), - SharableConfig::Private(x) => x.build(filters, &Default::default())?, + SharableConfig::Private(x) => x.build(proxy_table_cx)?, }; let users = self .users @@ -79,7 +70,6 @@ impl Socks5ServerTcpAccessServerConfig { Ok(Socks5ServerTcpAccessServerBuilder { listen_addr: self.listen_addr, proxy_table, - filter, speed_limit: self.speed_limit.unwrap_or(f64::INFINITY), udp_server_addr: self.udp_server_addr.map(|a| a.0), users, @@ -101,17 +91,14 @@ pub enum BuildError { ProxyTableKeyNotFound(Arc), #[error("Filter key not found: {0}")] FilterKeyNotFound(Arc), - #[error("Filter error: {0}")] - Filter(#[from] filter::FilterBuildError), #[error("{0}")] - ProxyTable(#[from] StreamProxyTableBuildError), + ProxyTable(#[from] ProxyTableBuildError), } #[derive(Debug, Clone)] pub struct Socks5ServerTcpAccessServerBuilder { listen_addr: Arc, - proxy_table: StreamProxyTable, - filter: Filter, + proxy_table: StreamProxyTable, speed_limit: f64, udp_server_addr: Option, users: HashMap, Arc<[u8]>>, @@ -137,7 +124,6 @@ impl loading::Builder for Socks5ServerTcpAccessServerBuilder { fn build_hook(self) -> Result { let access = Socks5ServerTcpAccess::new( self.proxy_table, - self.filter, self.speed_limit, self.udp_server_addr, self.users, @@ -149,16 +135,13 @@ impl loading::Builder for Socks5ServerTcpAccessServerBuilder { #[derive(Debug)] pub struct Socks5ServerTcpAccess { - proxy_table: StreamProxyTable, - filter: Filter, + proxy_table: StreamProxyTable, speed_limiter: Limiter, udp_listen_addr: Option, users: HashMap, Arc<[u8]>>, stream_context: ConcreteStreamContext, } - impl Hook for Socks5ServerTcpAccess {} - impl StreamServerHook for Socks5ServerTcpAccess { async fn handle_stream(&self, stream: S) where @@ -173,11 +156,9 @@ impl StreamServerHook for Socks5ServerTcpAccess { } } } - impl Socks5ServerTcpAccess { pub fn new( - proxy_table: StreamProxyTable, - filter: Filter, + proxy_table: StreamProxyTable, speed_limit: f64, udp_listen_addr: Option, users: HashMap, Arc<[u8]>>, @@ -185,7 +166,6 @@ impl Socks5ServerTcpAccess { ) -> Self { Self { proxy_table, - filter, speed_limiter: Limiter::new(speed_limit), udp_listen_addr, users, @@ -322,21 +302,6 @@ impl Socks5ServerTcpAccess { relay_request: RelayRequest, local_addr: SocketAddr, ) -> (RelayResponse, Result) { - // Filter - let action = self.filter.filter(&relay_request.destination); - if matches!(action, filter::Action::Block) { - let relay_response = RelayResponse { - reply: Reply::ConnectionNotAllowedByRuleset, - bind: InternetAddr::zero_ipv4_addr(), - }; - return ( - relay_response, - Ok(RequestResult::Blocked { - destination: relay_request.destination, - }), - ); - } - match relay_request.command { Command::Connect => (), Command::Bind => { @@ -364,54 +329,64 @@ impl Socks5ServerTcpAccess { }, } - fn general_socks_server_failure() -> RelayResponse { - RelayResponse { - reply: Reply::GeneralSocksServerFailure, - bind: InternetAddr::zero_ipv4_addr(), + // Filter + let action = self.proxy_table.action(&relay_request.destination); + let proxy_group = match action { + ProxyAction::Block => { + let relay_response = RelayResponse { + reply: Reply::ConnectionNotAllowedByRuleset, + 
bind: InternetAddr::zero_ipv4_addr(), + }; + return ( + relay_response, + Ok(RequestResult::Blocked { + destination: relay_request.destination, + }), + ); } - } - - if matches!(action, filter::Action::Direct) { - let sock_addr = match relay_request.destination.to_socket_addr().await { - Ok(sock_addr) => sock_addr, - Err(e) => { - return ( - general_socks_server_failure(), - Err(EstablishError::DirectConnect { - source: e, - destination: relay_request.destination.clone(), - }), - ); - } - }; - let upstream = match tokio::net::TcpStream::connect(sock_addr).await { - Ok(upstream) => upstream, - Err(e) => { - return ( - general_socks_server_failure(), - Err(EstablishError::DirectConnect { - source: e, - destination: relay_request.destination.clone(), - }), - ); - } - }; - let relay_response = RelayResponse { - reply: Reply::Succeeded, - bind: local_addr.into(), - }; - return ( - relay_response, - Ok(RequestResult::Direct { - upstream, - upstream_addr: relay_request.destination, - upstream_sock_addr: sock_addr, - }), - ); - } + ProxyAction::Direct => { + let sock_addr = match relay_request.destination.to_socket_addr().await { + Ok(sock_addr) => sock_addr, + Err(e) => { + return ( + general_socks_server_failure(), + Err(EstablishError::DirectConnect { + source: e, + destination: relay_request.destination.clone(), + }), + ); + } + }; + let upstream = match tokio::net::TcpStream::connect(sock_addr).await { + Ok(upstream) => upstream, + Err(e) => { + return ( + general_socks_server_failure(), + Err(EstablishError::DirectConnect { + source: e, + destination: relay_request.destination.clone(), + }), + ); + } + }; + let relay_response = RelayResponse { + reply: Reply::Succeeded, + bind: local_addr.into(), + }; + return ( + relay_response, + Ok(RequestResult::Direct { + upstream, + upstream_addr: relay_request.destination, + upstream_sock_addr: sock_addr, + }), + ); + } + ProxyAction::ProxyGroup(proxy_group) => proxy_group, + }; let (upstream, payload_crypto) = match self - .establish_proxy_chain(relay_request.destination.clone()) + .establish_proxy_chain(proxy_group, relay_request.destination.clone()) .await { Ok(res) => res, @@ -423,14 +398,21 @@ impl Socks5ServerTcpAccess { reply: Reply::Succeeded, bind: local_addr.into(), }; - ( + return ( relay_response, Ok(RequestResult::Proxy { destination: relay_request.destination, upstream, payload_crypto, }), - ) + ); + + fn general_socks_server_failure() -> RelayResponse { + RelayResponse { + reply: Reply::GeneralSocksServerFailure, + bind: InternetAddr::zero_ipv4_addr(), + } + } } async fn steer(&self, stream: S) -> io::Result<(S, RelayRequest)> @@ -531,13 +513,11 @@ impl Socks5ServerTcpAccess { async fn establish_proxy_chain( &self, + proxy_group: &StreamProxyGroup, destination: InternetAddr, ) -> Result<(ConnAndAddr, Option), EstablishProxyChainError> { - let Some(proxy_table_group) = self.proxy_table.group(&destination) else { - return Err(EstablishProxyChainError::NoProxy); - }; - let proxy_chain = proxy_table_group.choose_chain(); + let proxy_chain = proxy_group.choose_chain(); let res = proxy_client::stream::establish( &proxy_chain.chain, StreamAddr { @@ -553,8 +533,6 @@ impl Socks5ServerTcpAccess { #[derive(Debug, Error)] pub enum EstablishProxyChainError { - #[error("No proxy")] - NoProxy, #[error("{0}")] StreamEstablish(#[from] StreamEstablishError), } diff --git a/access_server/src/socks5/server/udp.rs b/access_server/src/socks5/server/udp.rs index dd8ac1c..12b2ccf 100644 --- a/access_server/src/socks5/server/udp.rs +++ 
b/access_server/src/socks5/server/udp.rs @@ -4,10 +4,11 @@ use async_speed_limit::Limiter; use common::{ config::SharableConfig, loading, + proxy_table::ProxyGroupBuildError, udp::{ context::UdpContext, io_copy::{CopyBidirectional, DownstreamParts, UpstreamParts}, - proxy_table::{UdpProxyConfig, UdpProxyTable}, + proxy_table::{UdpProxyGroup, UdpProxyGroupBuilder}, FlowOwnedGuard, Packet, UdpDownstreamWriter, UdpServer, UdpServerHook, UpstreamAddr, }, }; @@ -15,41 +16,36 @@ use proxy_client::udp::{EstablishError, UdpProxyClient}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::{net::ToSocketAddrs, sync::mpsc}; -use tokio_util::sync::CancellationToken; use tracing::{error, warn}; -use crate::{ - socks5::messages::UdpRequestHeader, - udp::proxy_table::{UdpProxyTableBuildError, UdpProxyTableBuilder}, -}; +use crate::{socks5::messages::UdpRequestHeader, udp::proxy_table::UdpProxyGroupBuildContext}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Socks5ServerUdpAccessServerConfig { pub listen_addr: Arc, - pub proxy_table: SharableConfig, + pub proxy_group: SharableConfig, pub speed_limit: Option, } impl Socks5ServerUdpAccessServerConfig { pub fn into_builder( self, - udp_proxy: &HashMap, UdpProxyConfig>, - proxy_tables: &HashMap, UdpProxyTable>, - cancellation: CancellationToken, + proxy_group: &HashMap, UdpProxyGroup>, + cx: UdpProxyGroupBuildContext<'_>, udp_context: UdpContext, ) -> Result { - let proxy_table = match self.proxy_table { - SharableConfig::SharingKey(key) => proxy_tables + let proxy_group = match self.proxy_group { + SharableConfig::SharingKey(key) => proxy_group .get(&key) - .ok_or_else(|| BuildError::ProxyTableKeyNotFound(key.clone()))? + .ok_or_else(|| BuildError::ProxyGroupKeyNotFound(key.clone()))? 
.clone(), - SharableConfig::Private(x) => x.build(udp_proxy, cancellation)?, + SharableConfig::Private(x) => x.build(cx)?, }; Ok(Socks5ServerUdpAccessServerBuilder { listen_addr: self.listen_addr, - proxy_table, + proxy_group, speed_limit: self.speed_limit.unwrap_or(f64::INFINITY), udp_context, }) @@ -58,16 +54,16 @@ impl Socks5ServerUdpAccessServerConfig { #[derive(Debug, Error)] pub enum BuildError { - #[error("Proxy table key not found: {0}")] - ProxyTableKeyNotFound(Arc), + #[error("Proxy group key not found: {0}")] + ProxyGroupKeyNotFound(Arc), #[error("{0}")] - ProxyTable(#[from] UdpProxyTableBuildError), + ProxyGroup(#[from] ProxyGroupBuildError), } #[derive(Debug, Clone)] pub struct Socks5ServerUdpAccessServerBuilder { listen_addr: Arc, - proxy_table: UdpProxyTable, + proxy_group: UdpProxyGroup, speed_limit: f64, udp_context: UdpContext, } @@ -90,7 +86,7 @@ impl loading::Builder for Socks5ServerUdpAccessServerBuilder { fn build_hook(self) -> Result { Ok(Socks5ServerUdpAccess::new( - self.proxy_table, + self.proxy_group, self.speed_limit, self.udp_context, )) @@ -99,7 +95,7 @@ impl loading::Builder for Socks5ServerUdpAccessServerBuilder { #[derive(Debug)] pub struct Socks5ServerUdpAccess { - proxy_table: UdpProxyTable, + proxy_group: UdpProxyGroup, speed_limiter: Limiter, udp_context: UdpContext, } @@ -107,9 +103,9 @@ pub struct Socks5ServerUdpAccess { impl loading::Hook for Socks5ServerUdpAccess {} impl Socks5ServerUdpAccess { - pub fn new(proxy_table: UdpProxyTable, speed_limit: f64, udp_context: UdpContext) -> Self { + pub fn new(proxy_group: UdpProxyGroup, speed_limit: f64, udp_context: UdpContext) -> Self { Self { - proxy_table, + proxy_group, speed_limiter: Limiter::new(speed_limit), udp_context, } @@ -127,10 +123,7 @@ impl Socks5ServerUdpAccess { downstream_writer: UdpDownstreamWriter, ) -> Result<(), AccessProxyError> { // Connect to upstream - let Some(proxy_table_group) = self.proxy_table.group(&flow.flow().upstream.0) else { - return Err(AccessProxyError::NoProxy); - }; - let proxy_chain = proxy_table_group.choose_chain(); + let proxy_chain = self.proxy_group.choose_chain(); let upstream = UdpProxyClient::establish(proxy_chain.chain.clone(), flow.flow().upstream.0.clone()) .await?; @@ -177,8 +170,6 @@ impl Socks5ServerUdpAccess { #[derive(Debug, Error)] pub enum AccessProxyError { - #[error("No proxy")] - NoProxy, #[error("Failed to establish proxy chain: {0}")] Establish(#[from] EstablishError), } diff --git a/access_server/src/stream/proxy_table.rs b/access_server/src/stream/proxy_table.rs index 457326e..f55138d 100644 --- a/access_server/src/stream/proxy_table.rs +++ b/access_server/src/stream/proxy_table.rs @@ -1,88 +1,157 @@ -use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; +use common::proxy_table::{ProxyGroupBuildContext, ProxyTableBuildContext}; +use protocol::stream::addr::ConcreteStreamAddr; +use proxy_client::stream::StreamTracerBuilder; -use common::{ - filter::MatcherBuilder, - proxy_table::{ProxyTable, ProxyTableError, ProxyTableGroup}, - stream::proxy_table::{ - StreamProxyConfig, StreamProxyConfigBuildError, StreamProxyTable, StreamProxyTableGroup, - StreamWeightedProxyChainBuilder, - }, -}; -use protocol::stream::{ - addr::{ConcreteStreamAddrStr, ConcreteStreamType}, - context::ConcreteStreamContext, -}; -use proxy_client::stream::StreamTracer; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio_util::sync::CancellationToken; +pub type StreamProxyTableBuildContext<'caller> = + ProxyTableBuildContext<'caller, 
ConcreteStreamAddr, StreamTracerBuilder>; +pub type StreamProxyGroupBuildContext<'caller> = + ProxyGroupBuildContext<'caller, ConcreteStreamAddr, StreamTracerBuilder>; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct StreamProxyTableBuilder { - pub groups: Vec, -} -impl StreamProxyTableBuilder { - pub fn build( - self, - stream_proxy: &HashMap, StreamProxyConfig>, - stream_context: &ConcreteStreamContext, - cancellation: CancellationToken, - ) -> Result, StreamProxyTableBuildError> { - let mut built = vec![]; - for group in self.groups { - let g = group.build(stream_proxy, stream_context, cancellation.clone())?; - built.push(g); - } - Ok(ProxyTable::new(built)) - } -} +// use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct StreamProxyTableGroupBuilder { - pub matcher: MatcherBuilder, - pub chains: Vec>, - pub trace_rtt: bool, - pub active_chains: Option, -} -impl StreamProxyTableGroupBuilder { - pub fn build( - self, - stream_proxy: &HashMap, StreamProxyConfig>, - stream_context: &ConcreteStreamContext, - cancellation: CancellationToken, - ) -> Result, StreamProxyTableBuildError> { - let matcher = self - .matcher - .build() - .map_err(StreamProxyTableBuildError::Matcher)?; - let chains = self - .chains - .into_iter() - .map(|c| c.build(stream_proxy)) - .collect::>() - .map_err(StreamProxyTableBuildError::ChainConfig)?; - let tracer = match self.trace_rtt { - true => Some(StreamTracer::new(stream_context.clone())), - false => None, - }; - Ok(ProxyTableGroup::new( - matcher, - chains, - tracer, - self.active_chains, - cancellation, - )?) - } -} +// use common::{ +// config::SharableConfig, +// filter::{Matcher, MatcherBuilder}, +// proxy_table::{ProxyAction, ProxyGroup, ProxyTable, ProxyTableEntry, ProxyTableError}, +// stream::proxy_table::{ +// StreamProxyConfig, StreamProxyConfigBuildError, StreamProxyGroup, StreamProxyTable, +// StreamProxyTableEntry, StreamProxyTableEntryAction, StreamWeightedProxyChainBuilder, +// }, +// }; +// use protocol::stream::{ +// addr::{ConcreteStreamAddrStr, ConcreteStreamType}, +// context::ConcreteStreamContext, +// }; +// use proxy_client::stream::StreamTracer; +// use serde::{Deserialize, Serialize}; +// use thiserror::Error; +// use tokio_util::sync::CancellationToken; -#[derive(Debug, Error)] -pub enum StreamProxyTableBuildError { - #[error("Matcher: {0}")] - Matcher(#[source] regex::Error), - #[error("Chain config is invalid: {0}")] - ChainConfig(#[source] StreamProxyConfigBuildError), - #[error("{0}")] - ProxyTable(#[from] ProxyTableError), -} +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(deny_unknown_fields)] +// pub struct StreamProxyTableBuilder { +// pub entries: Vec, +// } +// impl StreamProxyTableBuilder { +// pub fn build( +// self, +// cx: StreamProxyTableBuildContext<'_>, +// ) -> Result, StreamProxyTableBuildError> { +// let mut built = vec![]; +// for entry in self.entries { +// let e = entry.build(cx.clone())?; +// built.push(e); +// } +// Ok(ProxyTable::new(built)) +// } +// } +// #[derive(Debug, Clone)] +// pub struct StreamProxyTableBuildContext<'caller> { +// pub matcher: &'caller HashMap, Matcher>, +// pub stream_proxy_group: &'caller HashMap, StreamProxyGroup>, +// pub proxy_group_cx: StreamProxyGroupBuildContext<'caller>, +// } + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(deny_unknown_fields)] +// pub struct StreamProxyTableEntryBuilder { +// matcher: 
SharableConfig, +// #[serde(flatten)] +// action: StreamProxyActionBuilder, +// } +// impl StreamProxyTableEntryBuilder { +// pub fn build( +// self, +// cx: StreamProxyTableBuildContext<'_>, +// ) -> Result, StreamProxyTableBuildError> { +// let matcher = match self.matcher { +// SharableConfig::SharingKey(k) => cx +// .matcher +// .get(&k) +// .cloned() +// .ok_or_else(|| StreamProxyTableBuildError::KeyNotFound(k))?, +// SharableConfig::Private(v) => v.build().map_err(StreamProxyTableBuildError::Matcher)?, +// }; +// let action = self +// .action +// .build(cx.stream_proxy_group, cx.proxy_group_cx)?; +// Ok(ProxyTableEntry::new(matcher, action)) +// } +// } + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(deny_unknown_fields)] +// #[serde(rename = "snake_case")] +// pub enum StreamProxyActionBuilder { +// Direct, +// Block, +// ProxyGroup(SharableConfig), +// } +// impl StreamProxyActionBuilder { +// pub fn build( +// self, +// stream_proxy_group: &HashMap, StreamProxyGroup>, +// proxy_group_cx: StreamProxyGroupBuildContext<'_>, +// ) -> Result, StreamProxyTableBuildError> { +// Ok(match self { +// StreamProxyActionBuilder::Direct => ProxyAction::Direct, +// StreamProxyActionBuilder::Block => ProxyAction::Block, +// StreamProxyActionBuilder::ProxyGroup(p) => ProxyAction::ProxyGroup(Arc::new(match p { +// SharableConfig::SharingKey(k) => stream_proxy_group +// .get(&k) +// .cloned() +// .ok_or_else(|| StreamProxyTableBuildError::KeyNotFound(k))?, +// SharableConfig::Private(p) => p.build(proxy_group_cx)?, +// })), +// }) +// } +// } + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(deny_unknown_fields)] +// pub struct StreamProxyGroupBuilder { +// pub chains: Vec>, +// pub trace_rtt: bool, +// pub active_chains: Option, +// } +// impl StreamProxyGroupBuilder { +// pub fn build( +// self, +// cx: StreamProxyGroupBuildContext<'_>, +// ) -> Result, StreamProxyTableBuildError> { +// let chains = self +// .chains +// .into_iter() +// .map(|c| c.build(cx.stream_proxy_server)) +// .collect::>() +// .map_err(StreamProxyTableBuildError::ChainConfig)?; +// let tracer = match self.trace_rtt { +// true => Some(StreamTracer::new(cx.stream_context.clone())), +// false => None, +// }; +// Ok(ProxyGroup::new( +// chains, +// tracer, +// self.active_chains, +// cx.cancellation, +// )?) 
+// } +// } +// #[derive(Debug, Clone)] +// pub struct StreamProxyGroupBuildContext<'caller> { +// pub stream_proxy_server: &'caller HashMap, StreamProxyConfig>, +// pub stream_context: &'caller ConcreteStreamContext, +// pub cancellation: CancellationToken, +// } + +// #[derive(Debug, Error)] +// pub enum StreamProxyTableBuildError { +// #[error("Key not found: `{0}`")] +// KeyNotFound(Arc), +// #[error("Matcher: {0}")] +// Matcher(#[source] regex::Error), +// #[error("Chain config is invalid: {0}")] +// ChainConfig(#[source] StreamProxyConfigBuildError), +// #[error("{0}")] +// ProxyTable(#[from] ProxyTableError), +// } diff --git a/access_server/src/stream/streams/http_tunnel/mod.rs b/access_server/src/stream/streams/http_tunnel/mod.rs index 67b8943..b6cc46b 100644 --- a/access_server/src/stream/streams/http_tunnel/mod.rs +++ b/access_server/src/stream/streams/http_tunnel/mod.rs @@ -5,13 +5,12 @@ use bytes::Bytes; use common::{ addr::{InternetAddr, ParseInternetAddrError}, config::SharableConfig, - filter::{self, Filter, FilterBuilder}, loading, + proxy_table::{ProxyAction, ProxyTableBuildError}, stream::{ addr::StreamAddr, io_copy::{CopyBidirectional, MetricContext, DEAD_SESSION_RETENTION_DURATION}, metrics::{SimplifiedStreamMetrics, SimplifiedStreamProxyMetrics, StreamRecord}, - proxy_table::{StreamProxyConfig, StreamProxyTable}, session_table::{Session, StreamSessionTable}, IoAddr, IoStream, StreamServerHook, }, @@ -23,7 +22,10 @@ use hyper::{ use hyper_util::rt::TokioIo; use monitor_table::table::RowOwnedGuard; use protocol::stream::{ - addr::ConcreteStreamType, context::ConcreteStreamContext, streams::tcp::TcpServer, + addr::ConcreteStreamType, + context::ConcreteStreamContext, + proxy_table::{StreamProxyGroup, StreamProxyTable, StreamProxyTableBuilder}, + streams::tcp::TcpServer, }; use proxy_client::stream::{establish, StreamEstablishError}; use serde::{Deserialize, Serialize}; @@ -32,27 +34,23 @@ use tokio::{ io::{AsyncRead, AsyncWrite}, net::ToSocketAddrs, }; -use tokio_util::sync::CancellationToken; use tracing::{error, info, instrument, trace, warn}; -use crate::stream::proxy_table::{StreamProxyTableBuildError, StreamProxyTableBuilder}; +use crate::stream::proxy_table::StreamProxyTableBuildContext; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct HttpAccessServerConfig { pub listen_addr: Arc, pub proxy_table: SharableConfig, - pub filter: SharableConfig, pub speed_limit: Option, } impl HttpAccessServerConfig { pub fn into_builder( self, - stream_proxy: &HashMap, StreamProxyConfig>, - proxy_tables: &HashMap, StreamProxyTable>, - filters: &HashMap, Filter>, - cancellation: CancellationToken, + proxy_tables: &HashMap, StreamProxyTable>, + proxy_tables_cx: StreamProxyTableBuildContext<'_>, stream_context: ConcreteStreamContext, ) -> Result { let proxy_table = match self.proxy_table { @@ -60,20 +58,12 @@ impl HttpAccessServerConfig { .get(&key) .ok_or_else(|| BuildError::ProxyTableKeyNotFound(key.clone()))? .clone(), - SharableConfig::Private(x) => x.build(stream_proxy, &stream_context, cancellation)?, - }; - let filter = match self.filter { - SharableConfig::SharingKey(key) => filters - .get(&key) - .ok_or_else(|| BuildError::FilterKeyNotFound(key.clone()))? 
- .clone(), - SharableConfig::Private(x) => x.build(filters, &Default::default())?, + SharableConfig::Private(x) => x.build(proxy_tables_cx.clone())?, }; Ok(HttpAccessServerBuilder { listen_addr: self.listen_addr, proxy_table, - filter, speed_limit: self.speed_limit.unwrap_or(f64::INFINITY), stream_context, }) @@ -84,19 +74,14 @@ impl HttpAccessServerConfig { pub enum BuildError { #[error("Proxy table key not found: {0}")] ProxyTableKeyNotFound(Arc), - #[error("Filter key not found: {0}")] - FilterKeyNotFound(Arc), - #[error("Filter error: {0}")] - Filter(#[from] filter::FilterBuildError), #[error("{0}")] - ProxyTable(#[from] StreamProxyTableBuildError), + ProxyTable(#[from] ProxyTableBuildError), } #[derive(Debug, Clone)] pub struct HttpAccessServerBuilder { listen_addr: Arc, - proxy_table: StreamProxyTable, - filter: Filter, + proxy_table: StreamProxyTable, speed_limit: f64, stream_context: ConcreteStreamContext, } @@ -118,34 +103,26 @@ impl loading::Builder for HttpAccessServerBuilder { } fn build_hook(self) -> Result { - let access = HttpAccess::new( - self.proxy_table, - self.filter, - self.speed_limit, - self.stream_context, - ); + let access = HttpAccess::new(self.proxy_table, self.speed_limit, self.stream_context); Ok(access) } } #[derive(Debug)] pub struct HttpAccess { - proxy_table: Arc>, - filter: Filter, + proxy_table: Arc, speed_limiter: Limiter, stream_context: ConcreteStreamContext, } impl HttpAccess { pub fn new( - proxy_table: StreamProxyTable, - filter: Filter, + proxy_table: StreamProxyTable, speed_limit: f64, stream_context: ConcreteStreamContext, ) -> Self { Self { proxy_table: Arc::new(proxy_table), - filter, speed_limiter: Limiter::new(speed_limit), stream_context, } @@ -195,14 +172,14 @@ impl HttpAccess { stream_type: ConcreteStreamType::Tcp, }; - let action = self.filter.filter(&addr.address); - match action { - filter::Action::Proxy => (), - filter::Action::Block => { + let action = self.proxy_table.action(&addr.address); + let proxy_group = match action { + ProxyAction::ProxyGroup(proxy_group) => proxy_group, + ProxyAction::Block => { trace!(?addr, "Blocked {}", method); return Ok(respond_with_rejection()); } - filter::Action::Direct => { + ProxyAction::Direct => { let sock_addr = addr .address .to_socket_addr() @@ -228,14 +205,10 @@ impl HttpAccess { info!(%addr, "Direct {} finished", method); return res; } - } + }; // Establish proxy chain - let Some(proxy_table_group) = self.proxy_table.group(&addr.address) else { - trace!(?addr, "No proxy {}", method); - return Ok(respond_with_rejection()); - }; - let proxy_chain = proxy_table_group.choose_chain(); + let proxy_chain = proxy_group.choose_chain(); let upstream = establish(&proxy_chain.chain, addr.clone(), &self.stream_context).await?; let session_guard = self.stream_context.session_table.as_ref().map(|s| { @@ -309,18 +282,18 @@ impl HttpAccess { } }; let addr = addr.parse()?; - let action = self.filter.filter(&addr); + let action = self.proxy_table.action(&addr); let http_connect = match action { - filter::Action::Proxy => Some(HttpConnect::new( - Arc::clone(&self.proxy_table), + ProxyAction::ProxyGroup(proxy_group) => Some(HttpConnect::new( + Arc::clone(proxy_group), self.speed_limiter.clone(), self.stream_context.clone(), )), - filter::Action::Block => { + ProxyAction::Block => { trace!(?addr, "Blocked CONNECT"); return Ok(respond_with_rejection()); } - filter::Action::Direct => None, + ProxyAction::Direct => None, }; let speed_limiter = self.speed_limiter.clone(); @@ -473,19 +446,19 @@ impl StreamServerHook 
for HttpAccess { } struct HttpConnect { - proxy_table: Arc>, + proxy_group: Arc, speed_limiter: Limiter, stream_context: ConcreteStreamContext, } impl HttpConnect { pub fn new( - proxy_table: Arc>, + proxy_group: Arc, speed_limiter: Limiter, stream_context: ConcreteStreamContext, ) -> Self { Self { - proxy_table, + proxy_group, speed_limiter, stream_context, } @@ -504,10 +477,7 @@ impl HttpConnect { address: address.clone(), stream_type: ConcreteStreamType::Tcp, }; - let Some(proxy_table_group) = self.proxy_table.group(&address) else { - return Err(HttpConnectError::NoProxy); - }; - let proxy_chain = proxy_table_group.choose_chain(); + let proxy_chain = self.proxy_group.choose_chain(); let upstream = establish( &proxy_chain.chain, destination.clone(), @@ -541,8 +511,6 @@ impl HttpConnect { #[derive(Debug, Error)] pub enum HttpConnectError { - #[error("No proxy")] - NoProxy, #[error("Failed to establish proxy chain")] EstablishProxyChain(#[from] StreamEstablishError), } diff --git a/access_server/src/stream/streams/tcp.rs b/access_server/src/stream/streams/tcp.rs index 90ec0fa..49f8104 100644 --- a/access_server/src/stream/streams/tcp.rs +++ b/access_server/src/stream/streams/tcp.rs @@ -4,55 +4,54 @@ use async_speed_limit::Limiter; use common::{ config::SharableConfig, loading, + proxy_table::ProxyGroupBuildError, stream::{ io_copy::{CopyBidirectional, MetricContext}, - proxy_table::{StreamProxyConfig, StreamProxyTable}, IoAddr, IoStream, StreamServerHook, }, }; use protocol::stream::{ - addr::{ConcreteStreamAddr, ConcreteStreamAddrStr, ConcreteStreamType}, + addr::{ConcreteStreamAddr, ConcreteStreamAddrStr}, context::ConcreteStreamContext, + proxy_table::{StreamProxyGroup, StreamProxyGroupBuilder}, streams::tcp::TcpServer, }; use proxy_client::stream::{establish, StreamEstablishError}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::net::ToSocketAddrs; -use tokio_util::sync::CancellationToken; use tracing::{error, instrument, warn}; -use crate::stream::proxy_table::{StreamProxyTableBuildError, StreamProxyTableBuilder}; +use crate::stream::proxy_table::StreamProxyGroupBuildContext; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct TcpAccessServerConfig { pub listen_addr: Arc, pub destination: ConcreteStreamAddrStr, - pub proxy_table: SharableConfig, + pub proxy_group: SharableConfig, pub speed_limit: Option, } impl TcpAccessServerConfig { pub fn into_builder( self, - stream_proxy: &HashMap, StreamProxyConfig>, - proxy_tables: &HashMap, StreamProxyTable>, - cancellation: CancellationToken, + proxy_group: &HashMap, StreamProxyGroup>, + proxy_group_cx: StreamProxyGroupBuildContext<'_>, stream_context: ConcreteStreamContext, ) -> Result { - let proxy_table = match self.proxy_table { - SharableConfig::SharingKey(key) => proxy_tables + let proxy_group = match self.proxy_group { + SharableConfig::SharingKey(key) => proxy_group .get(&key) - .ok_or_else(|| BuildError::ProxyTableKeyNotFound(key.clone()))? + .ok_or_else(|| BuildError::ProxyGroupKeyNotFound(key.clone()))? 
.clone(), - SharableConfig::Private(x) => x.build(stream_proxy, &stream_context, cancellation)?, + SharableConfig::Private(x) => x.build(proxy_group_cx.clone())?, }; Ok(TcpAccessServerBuilder { listen_addr: self.listen_addr, destination: self.destination, - proxy_table, + proxy_group, speed_limit: self.speed_limit.unwrap_or(f64::INFINITY), stream_context, }) @@ -61,17 +60,17 @@ impl TcpAccessServerConfig { #[derive(Debug, Error)] pub enum BuildError { - #[error("Proxy table key not found: {0}")] - ProxyTableKeyNotFound(Arc), + #[error("Proxy group key not found: {0}")] + ProxyGroupKeyNotFound(Arc), #[error("{0}")] - ProxyTable(#[from] StreamProxyTableBuildError), + ProxyGroup(#[from] ProxyGroupBuildError), } #[derive(Debug, Clone)] pub struct TcpAccessServerBuilder { listen_addr: Arc, destination: ConcreteStreamAddrStr, - proxy_table: StreamProxyTable, + proxy_group: StreamProxyGroup, speed_limit: f64, stream_context: ConcreteStreamContext, } @@ -94,7 +93,7 @@ impl loading::Builder for TcpAccessServerBuilder { fn build_hook(self) -> Result { Ok(TcpAccess::new( - self.proxy_table, + self.proxy_group, self.destination.0, self.speed_limit, self.stream_context, @@ -104,7 +103,7 @@ impl loading::Builder for TcpAccessServerBuilder { #[derive(Debug)] pub struct TcpAccess { - proxy_table: StreamProxyTable, + proxy_group: StreamProxyGroup, destination: ConcreteStreamAddr, speed_limiter: Limiter, stream_context: ConcreteStreamContext, @@ -112,13 +111,13 @@ pub struct TcpAccess { impl TcpAccess { pub fn new( - proxy_table: StreamProxyTable, + proxy_group: StreamProxyGroup, destination: ConcreteStreamAddr, speed_limit: f64, stream_context: ConcreteStreamContext, ) -> Self { Self { - proxy_table, + proxy_group, destination, speed_limiter: Limiter::new(speed_limit), stream_context, @@ -134,10 +133,7 @@ impl TcpAccess { where S: IoStream + IoAddr, { - let Some(proxy_table_group) = self.proxy_table.group(&self.destination.address) else { - return Err(ProxyError::NoProxy); - }; - let proxy_chain = proxy_table_group.choose_chain(); + let proxy_chain = self.proxy_group.choose_chain(); let upstream = establish( &proxy_chain.chain, self.destination.clone(), @@ -171,8 +167,6 @@ impl TcpAccess { #[derive(Debug, Error)] pub enum ProxyError { - #[error("No proxy")] - NoProxy, #[error("Failed to get downstream address: {0}")] DownstreamAddr(#[source] io::Error), #[error("Failed to establish proxy chain: {0}")] diff --git a/access_server/src/udp/mod.rs b/access_server/src/udp/mod.rs index e9e3980..fbb0540 100644 --- a/access_server/src/udp/mod.rs +++ b/access_server/src/udp/mod.rs @@ -5,10 +5,11 @@ use common::{ addr::{InternetAddr, InternetAddrStr}, config::SharableConfig, loading, + proxy_table::ProxyGroupBuildError, udp::{ context::UdpContext, io_copy::{CopyBiError, CopyBidirectional, DownstreamParts, UpstreamParts}, - proxy_table::{UdpProxyConfig, UdpProxyTable}, + proxy_table::{UdpProxyGroup, UdpProxyGroupBuilder}, FlowOwnedGuard, Packet, UdpDownstreamWriter, UdpServer, UdpServerHook, UpstreamAddr, }, }; @@ -16,10 +17,9 @@ use proxy_client::udp::{EstablishError, UdpProxyClient}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::{net::ToSocketAddrs, sync::mpsc}; -use tokio_util::sync::CancellationToken; use tracing::{error, warn}; -use self::proxy_table::{UdpProxyTableBuildError, UdpProxyTableBuilder}; +use self::proxy_table::UdpProxyGroupBuildContext; pub mod proxy_table; @@ -28,30 +28,29 @@ pub mod proxy_table; pub struct UdpAccessServerConfig { pub listen_addr: Arc, pub destination: 
InternetAddrStr, - pub proxy_table: SharableConfig, + pub proxy_group: SharableConfig, pub speed_limit: Option, } impl UdpAccessServerConfig { pub fn into_builder( self, - udp_proxy: &HashMap, UdpProxyConfig>, - proxy_tables: &HashMap, UdpProxyTable>, - cancellation: CancellationToken, + proxy_group: &HashMap, UdpProxyGroup>, + cx: UdpProxyGroupBuildContext<'_>, udp_context: UdpContext, ) -> Result { - let proxy_table = match self.proxy_table { - SharableConfig::SharingKey(key) => proxy_tables + let proxy_group = match self.proxy_group { + SharableConfig::SharingKey(key) => proxy_group .get(&key) - .ok_or_else(|| BuildError::ProxyTableKeyNotFound(key.clone()))? + .ok_or_else(|| BuildError::ProxyGroupKeyNotFound(key.clone()))? .clone(), - SharableConfig::Private(x) => x.build(udp_proxy, cancellation.clone())?, + SharableConfig::Private(x) => x.build(cx)?, }; Ok(UdpAccessServerBuilder { listen_addr: self.listen_addr, destination: self.destination, - proxy_table, + proxy_group, speed_limit: self.speed_limit.unwrap_or(f64::INFINITY), udp_context, }) @@ -60,17 +59,17 @@ impl UdpAccessServerConfig { #[derive(Debug, Error)] pub enum BuildError { - #[error("Proxy table key not found: {0}")] - ProxyTableKeyNotFound(Arc), + #[error("Proxy group key not found: {0}")] + ProxyGroupKeyNotFound(Arc), #[error("{0}")] - ProxyTable(#[from] UdpProxyTableBuildError), + ProxyGroup(#[from] ProxyGroupBuildError), } #[derive(Debug, Clone)] pub struct UdpAccessServerBuilder { listen_addr: Arc, destination: InternetAddrStr, - proxy_table: UdpProxyTable, + proxy_group: UdpProxyGroup, speed_limit: f64, udp_context: UdpContext, } @@ -93,7 +92,7 @@ impl loading::Builder for UdpAccessServerBuilder { fn build_hook(self) -> Result { Ok(UdpAccess::new( - self.proxy_table, + self.proxy_group, self.destination.0, self.speed_limit, self.udp_context, @@ -103,7 +102,7 @@ impl loading::Builder for UdpAccessServerBuilder { #[derive(Debug)] pub struct UdpAccess { - proxy_table: UdpProxyTable, + proxy_group: UdpProxyGroup, destination: InternetAddr, speed_limiter: Limiter, udp_context: UdpContext, @@ -113,13 +112,13 @@ impl loading::Hook for UdpAccess {} impl UdpAccess { pub fn new( - proxy_table: UdpProxyTable, + proxy_table: UdpProxyGroup, destination: InternetAddr, speed_limit: f64, udp_context: UdpContext, ) -> Self { Self { - proxy_table, + proxy_group: proxy_table, destination, speed_limiter: Limiter::new(speed_limit), udp_context, @@ -138,10 +137,7 @@ impl UdpAccess { downstream_writer: UdpDownstreamWriter, ) -> Result<(), AccessProxyError> { // Connect to upstream - let Some(proxy_table_group) = self.proxy_table.group(&flow.flow().upstream.0) else { - return Err(AccessProxyError::NoProxy); - }; - let proxy_chain = proxy_table_group.choose_chain(); + let proxy_chain = self.proxy_group.choose_chain(); let upstream = UdpProxyClient::establish(proxy_chain.chain.clone(), self.destination.clone()).await?; let upstream_remote = upstream.remote_addr().clone(); @@ -177,8 +173,6 @@ impl UdpAccess { #[derive(Debug, Error)] pub enum AccessProxyError { - #[error("No proxy")] - NoProxy, #[error("Failed to establish proxy chain: {0}")] Establish(#[from] EstablishError), #[error("Failed to copy: {0}")] diff --git a/access_server/src/udp/proxy_table.rs b/access_server/src/udp/proxy_table.rs index 860c9ce..4820a09 100644 --- a/access_server/src/udp/proxy_table.rs +++ b/access_server/src/udp/proxy_table.rs @@ -1,82 +1,10 @@ -use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; - use common::{ - filter::MatcherBuilder, - 
proxy_table::{ProxyTable, ProxyTableError, ProxyTableGroup}, - udp::proxy_table::{ - UdpProxyConfig, UdpProxyConfigBuildError, UdpProxyTable, UdpProxyTableGroup, - UdpWeightedProxyChainBuilder, - }, + addr::InternetAddr, + proxy_table::{ProxyGroupBuildContext, ProxyTableBuildContext}, }; -use proxy_client::udp::UdpTracer; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio_util::sync::CancellationToken; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct UdpProxyTableBuilder { - pub groups: Vec, -} -impl UdpProxyTableBuilder { - pub fn build( - self, - udp_proxy: &HashMap, UdpProxyConfig>, - cancellation: CancellationToken, - ) -> Result { - let mut built = vec![]; - for group in self.groups { - let g = group.build(udp_proxy, cancellation.clone())?; - built.push(g); - } - Ok(ProxyTable::new(built)) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct UdpProxyTableGroupBuilder { - pub matcher: MatcherBuilder, - pub chains: Vec, - pub trace_rtt: bool, - pub active_chains: Option, -} -impl UdpProxyTableGroupBuilder { - pub fn build( - self, - udp_proxy: &HashMap, UdpProxyConfig>, - cancellation: CancellationToken, - ) -> Result { - let matcher = self - .matcher - .build() - .map_err(UdpProxyTableBuildError::Matcher)?; - let chains = self - .chains - .into_iter() - .map(|c| c.build(udp_proxy)) - .collect::>() - .map_err(UdpProxyTableBuildError::ChainConfig)?; - let tracer = match self.trace_rtt { - true => Some(UdpTracer::new()), - false => None, - }; - Ok(ProxyTableGroup::new( - matcher, - chains, - tracer, - self.active_chains, - cancellation, - )?) - } -} +use proxy_client::udp::UdpTracerBuilder; -#[derive(Debug, Error)] -pub enum UdpProxyTableBuildError { - #[error("Matcher: {0}")] - Matcher(#[source] regex::Error), - #[error("Chain config is invalid: {0}")] - ChainConfig(#[source] UdpProxyConfigBuildError), - #[error("{0}")] - ProxyTable(#[from] ProxyTableError), -} +pub type UdpProxyTableBuildContext<'caller> = + ProxyTableBuildContext<'caller, InternetAddr, UdpTracerBuilder>; +pub type UdpProxyGroupBuildContext<'caller> = + ProxyGroupBuildContext<'caller, InternetAddr, UdpTracerBuilder>; diff --git a/common/src/addr.rs b/common/src/addr.rs index f4da19e..ce192aa 100644 --- a/common/src/addr.rs +++ b/common/src/addr.rs @@ -14,6 +14,8 @@ use serde::{de::Visitor, Deserialize, Serialize}; use thiserror::Error; use tokio::net::lookup_host; +use crate::proxy_table::AddressString; + static RESOLVED_SOCKET_ADDR: Lazy, IpAddr>>> = Lazy::new(|| Mutex::new(LruCache::new(NonZeroUsize::new(128).unwrap()))); @@ -146,7 +148,13 @@ impl InternetAddr { #[derive(Debug, Clone)] pub struct InternetAddrStr(pub InternetAddr); +impl AddressString for InternetAddrStr { + type Address = InternetAddr; + fn into_address(self) -> Self::Address { + self.0 + } +} impl Serialize for InternetAddrStr { fn serialize(&self, serializer: S) -> Result where @@ -155,7 +163,6 @@ impl Serialize for InternetAddrStr { serializer.serialize_str(&self.0.to_string()) } } - impl<'de> Deserialize<'de> for InternetAddrStr { fn deserialize(deserializer: D) -> Result where @@ -164,9 +171,7 @@ impl<'de> Deserialize<'de> for InternetAddrStr { deserializer.deserialize_str(InternetAddrStrVisitor) } } - struct InternetAddrStrVisitor; - impl<'de> Visitor<'de> for InternetAddrStrVisitor { type Value = InternetAddrStr; diff --git a/common/src/proxy_table.rs b/common/src/proxy_table.rs deleted file mode 100644 index f143842..0000000 --- 
a/common/src/proxy_table.rs +++ /dev/null @@ -1,402 +0,0 @@ -use std::{ - fmt::Display, - num::NonZeroUsize, - sync::{Arc, RwLock}, - time::Duration, -}; - -use rand::Rng; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio_util::sync::CancellationToken; -use tracing::{info, trace}; - -use crate::{ - addr::InternetAddr, cache_cell::CacheCell, error::AnyError, filter::Matcher, - header::route::RouteRequest, -}; - -const TRACE_INTERVAL: Duration = Duration::from_secs(30); -const TRACE_DEAD_INTERVAL: Duration = Duration::from_secs(60 * 2); -const TRACES_PER_WAVE: usize = 60; -const TRACE_BURST_GAP: Duration = Duration::from_millis(10); -const RTT_TIMEOUT: Duration = Duration::from_secs(60); - -pub type ProxyChain = [ProxyConfig]; - -#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)] -pub struct ProxyConfig { - pub address: A, - pub crypto: tokio_chacha20::config::Config, -} - -/// # Panic -/// -/// `nodes` must not be empty. -pub fn convert_proxies_to_header_crypto_pairs( - nodes: &ProxyChain, - destination: Option, -) -> Vec<(RouteRequest, &tokio_chacha20::config::Config)> -where - A: Clone + Sync + Send, -{ - assert!(!nodes.is_empty()); - let mut pairs = (0..nodes.len() - 1) - .map(|i| { - let node = &nodes[i]; - let next_node = &nodes[i + 1]; - let route_req = RouteRequest { - upstream: Some(next_node.address.clone()), - }; - (route_req, &node.crypto) - }) - .collect::>(); - let route_req = RouteRequest { - upstream: destination, - }; - pairs.push((route_req, &nodes.last().unwrap().crypto)); - pairs -} - -#[derive(Debug, Clone)] -pub struct ProxyTable { - groups: Vec>, -} -impl ProxyTable -where - A: std::fmt::Debug + Display + Clone + Send + Sync + 'static, -{ - pub fn new(groups: Vec>) -> Self { - Self { groups } - } - - pub fn group(&self, addr: &InternetAddr) -> Option<&ProxyTableGroup> { - self.groups - .iter() - .find(|&group| group.matcher().matches(addr)) - } -} - -#[derive(Debug, Clone)] -pub struct ProxyTableGroup { - matcher: Matcher, - chains: Arc<[GaugedProxyChain]>, - cum_weight: NonZeroUsize, - score_store: Arc>, - active_chains: NonZeroUsize, -} -impl ProxyTableGroup -where - A: std::fmt::Debug + Display + Clone + Send + Sync + 'static, -{ - pub fn new( - matcher: Matcher, - chains: Vec>, - tracer: Option, - active_chains: Option, - cancellation: CancellationToken, - ) -> Result - where - T: Tracer
+ Send + Sync + 'static, - { - let cum_weight = chains.iter().map(|c| c.weight).sum(); - if cum_weight == 0 { - return Err(ProxyTableError::ZeroAccumulatedWeight); - } - let cum_weight = NonZeroUsize::new(cum_weight).unwrap(); - - let active_chains = match active_chains { - Some(active_chains) => { - if active_chains.get() > chains.len() { - return Err(ProxyTableError::TooManyActiveChains); - } - active_chains - } - None => NonZeroUsize::new(chains.len()).unwrap(), - }; - - let tracer = tracer.map(Arc::new); - let chains = chains - .into_iter() - .map(|c| GaugedProxyChain::new(c, tracer.clone(), cancellation.clone())) - .collect::>(); - let score_store = Arc::new(RwLock::new(ScoreStore::new(None, TRACE_INTERVAL))); - Ok(Self { - matcher, - chains, - cum_weight, - score_store, - active_chains, - }) - } - - pub fn matcher(&self) -> &Matcher { - &self.matcher - } - - pub fn choose_chain(&self) -> &WeightedProxyChain { - if self.chains.len() == 1 { - return self.chains[0].weighted(); - } - - let scores = self.score_store.read().unwrap().get().cloned(); - let scores = match scores { - Some(scores) => scores, - None => { - let scores: Arc<[_]> = self.scores().into(); - info!(?scores, "Calculated scores"); - let sum = scores.iter().map(|(_, s)| *s).sum::(); - let scores = Scores { scores, sum }; - self.score_store.write().unwrap().set(scores.clone()); - scores - } - }; - - let mut rng = rand::thread_rng(); - if scores.sum == 0. { - let i = rng.gen_range(0..self.chains.len()); - return self.chains[i].weighted(); - } - let mut rand_score = rng.gen_range(0. ..scores.sum); - for &(i, score) in scores.scores.iter() { - if rand_score < score { - return self.chains[i].weighted(); - } - rand_score -= score; - } - unreachable!(); - } - - fn scores(&self) -> Vec<(usize, f64)> { - let weights_hat = self - .chains - .iter() - .map(|c| c.weighted().weight as f64 / self.cum_weight.get() as f64) - .collect::>(); - - let rtt = self - .chains - .iter() - .map(|c| c.rtt().map(|r| r.as_secs_f64())) - .collect::>(); - let rtt_hat = normalize(&rtt); - - let losses = self.chains.iter().map(|c| c.loss()).collect::>(); - let losses_hat = normalize(&losses); - - let mut scores = (0..self.chains.len()) - .map(|i| (1. - losses_hat[i]).powi(3) * (1. - rtt_hat[i]).powi(2) * weights_hat[i]) - .enumerate() - .collect::>(); - - scores.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); - scores[..self.active_chains.get()].to_vec() - } -} -#[derive(Debug, Error)] -pub enum ProxyTableError { - #[error("Zero accumulated weight with chains")] - ZeroAccumulatedWeight, - #[error("The number of active chains is more than the number of chains")] - TooManyActiveChains, -} - -fn normalize(list: &[Option]) -> Vec { - let sum_some: f64 = list.iter().map(|x| x.unwrap_or(0.)).sum(); - let count_some = list.iter().map(|x| x.map(|_| 1).unwrap_or(0)).sum(); - let hat = match count_some { - 0 => { - let hat_mean = 1. / list.len() as f64; - (0..list.len()).map(|_| hat_mean).collect::>() - } - _ => { - let mean = sum_some / count_some as f64; - let sum: f64 = list.iter().map(|x| x.unwrap_or(mean)).sum(); - if sum == 0. 
{ - (0..list.len()).map(|_| 0.).collect::>() - } else { - let hat_mean = mean / sum; - list.iter() - .map(|x| x.map(|x| x / sum).unwrap_or(hat_mean)) - .collect::>() - } - } - }; - hat -} - -type ScoreStore = CacheCell; - -#[derive(Debug, Clone)] -struct Scores { - scores: Arc<[(usize, f64)]>, - sum: f64, -} - -#[derive(Debug)] -pub struct WeightedProxyChain { - pub weight: usize, - pub chain: Arc>, - pub payload_crypto: Option, -} - -#[derive(Debug)] -struct GaugedProxyChain { - weighted: WeightedProxyChain, - rtt: Arc>>, - loss: Arc>>, - task_handle: Option>, -} - -impl GaugedProxyChain -where - A: std::fmt::Debug + Display + Clone + Send + Sync + 'static, -{ - pub fn new( - weighted: WeightedProxyChain, - tracer: Option>, - cancellation: CancellationToken, - ) -> Self - where - T: Tracer
+ Send + Sync + 'static, - { - let rtt = Arc::new(RwLock::new(None)); - let loss = Arc::new(RwLock::new(None)); - let task_handle = tracer.map(|tracer| { - spawn_tracer( - tracer, - weighted.chain.clone(), - rtt.clone(), - loss.clone(), - cancellation, - ) - }); - Self { - weighted, - rtt, - loss, - task_handle, - } - } - - pub fn weighted(&self) -> &WeightedProxyChain { - &self.weighted - } - - pub fn rtt(&self) -> Option { - *self.rtt.read().unwrap() - } - - pub fn loss(&self) -> Option { - *self.loss.read().unwrap() - } -} - -fn spawn_tracer( - tracer: Arc, - chain: Arc>, - rtt_store: Arc>>, - loss_store: Arc>>, - cancellation: CancellationToken, -) -> tokio::task::JoinHandle<()> -where - T: Tracer
+ Send + Sync + 'static, - A: Display + Send + Sync + 'static, -{ - tokio::task::spawn(async move { - let mut wave = tokio::task::JoinSet::new(); - while !cancellation.is_cancelled() { - // Spawn tracing tasks - for _ in 0..TRACES_PER_WAVE { - let chain = chain.clone(); - let tracer = tracer.clone(); - wave.spawn(async move { - tokio::time::timeout(RTT_TIMEOUT, tracer.trace_rtt(&chain)).await - }); - tokio::time::sleep(TRACE_BURST_GAP).await; - } - - // Collect RTT - let mut rtt_sum = Duration::from_secs(0); - let mut rtt_count: usize = 0; - while let Some(res) = wave.join_next().await { - let res = match res.unwrap() { - Ok(res) => res, - Err(_) => { - trace!("Trace timeout"); - continue; - } - }; - match res { - Ok(rtt) => { - rtt_sum += rtt; - rtt_count += 1; - } - Err(e) => { - trace!("{:?}", e); - } - } - } - let rtt = if rtt_count == 0 { - None - } else { - Some(rtt_sum / (rtt_count as u32)) - }; - let loss = (TRACES_PER_WAVE - rtt_count) as f64 / TRACES_PER_WAVE as f64; - - // Store RTT - let addresses = DisplayChain(&chain); - info!(%addresses, ?rtt, ?loss, "Traced RTT"); - { - let mut rtt_store = rtt_store.write().unwrap(); - *rtt_store = rtt; - } - { - let mut loss_store = loss_store.write().unwrap(); - *loss_store = Some(loss); - } - - // Sleep - if rtt_count == 0 { - tokio::time::sleep(TRACE_DEAD_INTERVAL).await; - } else { - tokio::time::sleep(TRACE_INTERVAL).await; - } - } - }) -} - -impl Drop for GaugedProxyChain { - fn drop(&mut self) { - if let Some(h) = self.task_handle.as_ref() { - h.abort() - } - } -} - -pub struct DisplayChain<'chain, A>(&'chain ProxyChain); - -impl<'chain, A> Display for DisplayChain<'chain, A> -where - A: Display, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[")?; - for (i, c) in self.0.iter().enumerate() { - write!(f, "{}", c.address)?; - if i + 1 != self.0.len() { - write!(f, ",")?; - } - } - write!(f, "]")?; - Ok(()) - } -} - -pub trait Tracer { - type Address; - fn trace_rtt( - &self, - chain: &ProxyChain, - ) -> impl std::future::Future> + Send; -} diff --git a/common/src/proxy_table/mod.rs b/common/src/proxy_table/mod.rs new file mode 100644 index 0000000..20fe4dc --- /dev/null +++ b/common/src/proxy_table/mod.rs @@ -0,0 +1,8 @@ +mod proxy_chain; +mod proxy_group; +mod proxy_server; +mod table; +pub use proxy_chain::*; +pub use proxy_group::*; +pub use proxy_server::*; +pub use table::*; diff --git a/common/src/proxy_table/proxy_chain.rs b/common/src/proxy_table/proxy_chain.rs new file mode 100644 index 0000000..103f80d --- /dev/null +++ b/common/src/proxy_table/proxy_chain.rs @@ -0,0 +1,270 @@ +use std::{ + collections::HashMap, + fmt, + sync::{Arc, RwLock}, + time::Duration, +}; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio_util::sync::CancellationToken; +use tracing::{info, trace}; + +use crate::{config::SharableConfig, error::AnyError, header::route::RouteRequest}; + +use super::{AddressString, ProxyConfig, ProxyConfigBuildError, ProxyConfigBuilder}; + +pub const TRACE_INTERVAL: Duration = Duration::from_secs(30); +const TRACE_DEAD_INTERVAL: Duration = Duration::from_secs(60 * 2); +const TRACES_PER_WAVE: usize = 60; +const TRACE_BURST_GAP: Duration = Duration::from_millis(10); +const RTT_TIMEOUT: Duration = Duration::from_secs(60); + +pub type ProxyChain = [ProxyConfig]; + +/// # Panic +/// +/// `nodes` must not be empty. 
+pub fn convert_proxies_to_header_crypto_pairs( + nodes: &ProxyChain, + destination: Option, +) -> Vec<(RouteRequest, &tokio_chacha20::config::Config)> +where + A: Clone + Sync + Send, +{ + assert!(!nodes.is_empty()); + let mut pairs = (0..nodes.len() - 1) + .map(|i| { + let node = &nodes[i]; + let next_node = &nodes[i + 1]; + let route_req = RouteRequest { + upstream: Some(next_node.address.clone()), + }; + (route_req, &node.header_crypto) + }) + .collect::>(); + let route_req = RouteRequest { + upstream: destination, + }; + pairs.push((route_req, &nodes.last().unwrap().header_crypto)); + pairs +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct WeightedProxyChainBuilder { + pub weight: usize, + pub chain: Vec>>, +} +impl WeightedProxyChainBuilder { + pub fn build( + self, + proxy_server: &HashMap, ProxyConfig>, + ) -> Result, WeightedProxyChainBuildError> + where + AS: AddressString
, + { + let chain = self + .chain + .into_iter() + .map(|c| match c { + SharableConfig::SharingKey(k) => proxy_server + .get(&k) + .cloned() + .ok_or_else(|| WeightedProxyChainBuildError::ProxyServerKeyNotFound(k)), + SharableConfig::Private(c) => c.build().map_err(Into::into), + }) + .collect::, _>>()?; + let mut payload_crypto = None; + for proxy_config in chain.iter() { + let Some(p) = &proxy_config.payload_crypto else { + continue; + }; + if payload_crypto.is_some() { + return Err(WeightedProxyChainBuildError::MultiplePayloadKeys); + } + payload_crypto = Some(p.clone()); + } + Ok(WeightedProxyChain { + weight: self.weight, + chain, + payload_crypto, + }) + } +} +#[derive(Debug, Error)] +pub enum WeightedProxyChainBuildError { + #[error("{0}")] + ProxyServer(#[from] ProxyConfigBuildError), + #[error("Proxy server key not found: {0}")] + ProxyServerKeyNotFound(Arc), + #[error("Multiple payload keys")] + MultiplePayloadKeys, +} + +#[derive(Debug)] +pub struct WeightedProxyChain { + pub weight: usize, + pub chain: Arc>, + pub payload_crypto: Option, +} + +#[derive(Debug)] +pub struct GaugedProxyChain { + weighted: WeightedProxyChain, + rtt: Arc>>, + loss: Arc>>, + task_handle: Option>, +} +impl GaugedProxyChain +where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, +{ + pub fn new( + weighted: WeightedProxyChain, + tracer: Option>, + cancellation: CancellationToken, + ) -> Self + where + T: Tracer
+ Send + Sync + 'static, + { + let rtt = Arc::new(RwLock::new(None)); + let loss = Arc::new(RwLock::new(None)); + let task_handle = tracer.map(|tracer| { + spawn_tracer( + tracer, + weighted.chain.clone(), + rtt.clone(), + loss.clone(), + cancellation, + ) + }); + Self { + weighted, + rtt, + loss, + task_handle, + } + } + + pub fn weighted(&self) -> &WeightedProxyChain { + &self.weighted + } + + pub fn rtt(&self) -> Option { + *self.rtt.read().unwrap() + } + + pub fn loss(&self) -> Option { + *self.loss.read().unwrap() + } +} + +fn spawn_tracer( + tracer: Arc, + chain: Arc>, + rtt_store: Arc>>, + loss_store: Arc>>, + cancellation: CancellationToken, +) -> tokio::task::JoinHandle<()> +where + T: Tracer
+ Send + Sync + 'static, + A: fmt::Display + Send + Sync + 'static, +{ + tokio::task::spawn(async move { + let mut wave = tokio::task::JoinSet::new(); + while !cancellation.is_cancelled() { + // Spawn tracing tasks + for _ in 0..TRACES_PER_WAVE { + let chain = chain.clone(); + let tracer = tracer.clone(); + wave.spawn(async move { + tokio::time::timeout(RTT_TIMEOUT, tracer.trace_rtt(&chain)).await + }); + tokio::time::sleep(TRACE_BURST_GAP).await; + } + + // Collect RTT + let mut rtt_sum = Duration::from_secs(0); + let mut rtt_count: usize = 0; + while let Some(res) = wave.join_next().await { + let res = match res.unwrap() { + Ok(res) => res, + Err(_) => { + trace!("Trace timeout"); + continue; + } + }; + match res { + Ok(rtt) => { + rtt_sum += rtt; + rtt_count += 1; + } + Err(e) => { + trace!("{:?}", e); + } + } + } + let rtt = if rtt_count == 0 { + None + } else { + Some(rtt_sum / (rtt_count as u32)) + }; + let loss = (TRACES_PER_WAVE - rtt_count) as f64 / TRACES_PER_WAVE as f64; + + // Store RTT + let addresses = DisplayChain(&chain); + info!(%addresses, ?rtt, ?loss, "Traced RTT"); + { + let mut rtt_store = rtt_store.write().unwrap(); + *rtt_store = rtt; + } + { + let mut loss_store = loss_store.write().unwrap(); + *loss_store = Some(loss); + } + + // Sleep + if rtt_count == 0 { + tokio::time::sleep(TRACE_DEAD_INTERVAL).await; + } else { + tokio::time::sleep(TRACE_INTERVAL).await; + } + } + }) +} + +impl Drop for GaugedProxyChain { + fn drop(&mut self) { + if let Some(h) = self.task_handle.as_ref() { + h.abort() + } + } +} + +pub struct DisplayChain<'chain, A>(&'chain ProxyChain); + +impl<'chain, A> fmt::Display for DisplayChain<'chain, A> +where + A: fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "[")?; + for (i, c) in self.0.iter().enumerate() { + write!(f, "{}", c.address)?; + if i + 1 != self.0.len() { + write!(f, ",")?; + } + } + write!(f, "]")?; + Ok(()) + } +} + +pub trait Tracer { + type Address; + fn trace_rtt( + &self, + chain: &ProxyChain, + ) -> impl std::future::Future> + Send; +} diff --git a/common/src/proxy_table/proxy_group.rs b/common/src/proxy_table/proxy_group.rs new file mode 100644 index 0000000..5eef0fd --- /dev/null +++ b/common/src/proxy_table/proxy_group.rs @@ -0,0 +1,232 @@ +use std::{ + collections::HashMap, + fmt, + num::NonZeroUsize, + sync::{Arc, RwLock}, +}; + +use rand::Rng; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio_util::sync::CancellationToken; +use tracing::info; + +use crate::cache_cell::CacheCell; + +use super::{ + AddressString, GaugedProxyChain, ProxyConfig, Tracer, WeightedProxyChain, + WeightedProxyChainBuildError, WeightedProxyChainBuilder, TRACE_INTERVAL, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ProxyGroupBuilder { + pub chains: Vec>, + pub trace_rtt: bool, + pub active_chains: Option, +} +impl ProxyGroupBuilder { + pub fn build( + self, + cx: ProxyGroupBuildContext<'_, A, TB>, + ) -> Result, ProxyGroupBuildError> + where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, + AS: AddressString
, + TB: TracerBuilder, + T: Tracer
+ Sync + Send + 'static, + { + let chains = self + .chains + .into_iter() + .map(|c| c.build(cx.proxy_server)) + .collect::>() + .map_err(ProxyGroupBuildError::ChainConfig)?; + let tracer = match self.trace_rtt { + true => Some(cx.tracer_builder.build()), + false => None, + }; + Ok(ProxyGroup::new( + chains, + tracer, + self.active_chains, + cx.cancellation, + )?) + } +} +#[derive(Debug, Error)] +pub enum ProxyGroupBuildError { + #[error("Chain config is invalid: {0}")] + ChainConfig(#[source] WeightedProxyChainBuildError), + #[error("{0}")] + ProxyGroup(#[from] ProxyGroupError), +} +#[derive(Debug)] +pub struct ProxyGroupBuildContext<'caller, A, TB> { + pub proxy_server: &'caller HashMap, ProxyConfig>, + pub tracer_builder: &'caller TB, + pub cancellation: CancellationToken, +} +impl<'caller, A, TB> Clone for ProxyGroupBuildContext<'caller, A, TB> { + fn clone(&self) -> Self { + Self { + proxy_server: self.proxy_server, + tracer_builder: self.tracer_builder, + cancellation: self.cancellation.clone(), + } + } +} + +pub trait TracerBuilder { + type Tracer: Tracer + Send + Sync + 'static; + fn build(&self) -> Self::Tracer; +} + +#[derive(Debug, Clone)] +pub struct ProxyGroup { + chains: Arc<[GaugedProxyChain]>, + cum_weight: NonZeroUsize, + score_store: Arc>, + active_chains: NonZeroUsize, +} +impl ProxyGroup +where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, +{ + pub fn new( + chains: Vec>, + tracer: Option, + active_chains: Option, + cancellation: CancellationToken, + ) -> Result + where + T: Tracer
+ Send + Sync + 'static, + { + let cum_weight = chains.iter().map(|c| c.weight).sum(); + if cum_weight == 0 { + return Err(ProxyGroupError::ZeroAccumulatedWeight); + } + let cum_weight = NonZeroUsize::new(cum_weight).unwrap(); + + let active_chains = match active_chains { + Some(active_chains) => { + if active_chains.get() > chains.len() { + return Err(ProxyGroupError::TooManyActiveChains); + } + active_chains + } + None => NonZeroUsize::new(chains.len()).unwrap(), + }; + + let tracer = tracer.map(Arc::new); + let chains = chains + .into_iter() + .map(|c| GaugedProxyChain::new(c, tracer.clone(), cancellation.clone())) + .collect::>(); + let score_store = Arc::new(RwLock::new(ScoreStore::new(None, TRACE_INTERVAL))); + Ok(Self { + chains, + cum_weight, + score_store, + active_chains, + }) + } + + pub fn choose_chain(&self) -> &WeightedProxyChain { + if self.chains.len() == 1 { + return self.chains[0].weighted(); + } + + let scores = self.score_store.read().unwrap().get().cloned(); + let scores = match scores { + Some(scores) => scores, + None => { + let scores: Arc<[_]> = self.scores().into(); + info!(?scores, "Calculated scores"); + let sum = scores.iter().map(|(_, s)| *s).sum::(); + let scores = Scores { scores, sum }; + self.score_store.write().unwrap().set(scores.clone()); + scores + } + }; + + let mut rng = rand::thread_rng(); + if scores.sum == 0. { + let i = rng.gen_range(0..self.chains.len()); + return self.chains[i].weighted(); + } + let mut rand_score = rng.gen_range(0. ..scores.sum); + for &(i, score) in scores.scores.iter() { + if rand_score < score { + return self.chains[i].weighted(); + } + rand_score -= score; + } + unreachable!(); + } + + fn scores(&self) -> Vec<(usize, f64)> { + let weights_hat = self + .chains + .iter() + .map(|c| c.weighted().weight as f64 / self.cum_weight.get() as f64) + .collect::>(); + + let rtt = self + .chains + .iter() + .map(|c| c.rtt().map(|r| r.as_secs_f64())) + .collect::>(); + let rtt_hat = normalize(&rtt); + + let losses = self.chains.iter().map(|c| c.loss()).collect::>(); + let losses_hat = normalize(&losses); + + let mut scores = (0..self.chains.len()) + .map(|i| (1. - losses_hat[i]).powi(3) * (1. - rtt_hat[i]).powi(2) * weights_hat[i]) + .enumerate() + .collect::>(); + + scores.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); + scores[..self.active_chains.get()].to_vec() + } +} +#[derive(Debug, Error, Clone)] +pub enum ProxyGroupError { + #[error("Zero accumulated weight with chains")] + ZeroAccumulatedWeight, + #[error("The number of active chains is more than the number of chains")] + TooManyActiveChains, +} + +type ScoreStore = CacheCell; + +#[derive(Debug, Clone)] +struct Scores { + scores: Arc<[(usize, f64)]>, + sum: f64, +} + +fn normalize(list: &[Option]) -> Vec { + let sum_some: f64 = list.iter().map(|x| x.unwrap_or(0.)).sum(); + let count_some = list.iter().map(|x| x.map(|_| 1).unwrap_or(0)).sum(); + let hat = match count_some { + 0 => { + let hat_mean = 1. / list.len() as f64; + (0..list.len()).map(|_| hat_mean).collect::>() + } + _ => { + let mean = sum_some / count_some as f64; + let sum: f64 = list.iter().map(|x| x.unwrap_or(mean)).sum(); + if sum == 0. 
{ + (0..list.len()).map(|_| 0.).collect::>() + } else { + let hat_mean = mean / sum; + list.iter() + .map(|x| x.map(|x| x / sum).unwrap_or(hat_mean)) + .collect::>() + } + } + }; + hat +} diff --git a/common/src/proxy_table/proxy_server.rs b/common/src/proxy_table/proxy_server.rs new file mode 100644 index 0000000..0edad3a --- /dev/null +++ b/common/src/proxy_table/proxy_server.rs @@ -0,0 +1,42 @@ +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ProxyConfigBuilder { + pub address: AS, + pub header_key: tokio_chacha20::config::ConfigBuilder, + pub payload_key: Option, +} +impl ProxyConfigBuilder { + pub fn build(self) -> Result, ProxyConfigBuildError> + where + AS: AddressString
, + { + let header_crypto = self.header_key.build()?; + let payload_crypto = self.payload_key.map(|p| p.build()).transpose()?; + let address = self.address.into_address(); + Ok(ProxyConfig { + address, + header_crypto, + payload_crypto, + }) + } +} +#[derive(Debug, Error)] +pub enum ProxyConfigBuildError { + #[error("{0}")] + Crypto(#[from] tokio_chacha20::config::ConfigBuildError), +} + +pub trait AddressString: Serialize + DeserializeOwned { + type Address; + fn into_address(self) -> Self::Address; +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)] +pub struct ProxyConfig { + pub address: A, + pub header_crypto: tokio_chacha20::config::Config, + pub payload_crypto: Option, +} diff --git a/common/src/proxy_table/table.rs b/common/src/proxy_table/table.rs new file mode 100644 index 0000000..63741a1 --- /dev/null +++ b/common/src/proxy_table/table.rs @@ -0,0 +1,193 @@ +use std::{collections::HashMap, fmt, sync::Arc}; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::{ + addr::InternetAddr, + config::SharableConfig, + filter::{Matcher, MatcherBuilder}, +}; + +use super::{ + AddressString, ProxyConfigBuildError, ProxyGroup, ProxyGroupBuildContext, ProxyGroupBuildError, + ProxyGroupBuilder, Tracer, TracerBuilder, +}; + +#[derive(Debug)] +pub struct ProxyTableBuildContext<'caller, A, TB> { + pub matcher: &'caller HashMap, Matcher>, + pub proxy_group: &'caller HashMap, ProxyGroup>, + pub proxy_group_cx: ProxyGroupBuildContext<'caller, A, TB>, +} +impl<'caller, A, TB> Clone for ProxyTableBuildContext<'caller, A, TB> { + fn clone(&self) -> Self { + Self { + matcher: self.matcher, + proxy_group: self.proxy_group, + proxy_group_cx: self.proxy_group_cx.clone(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(transparent)] +pub struct ProxyTableBuilder { + #[serde(flatten)] + pub entries: Vec>, +} +impl ProxyTableBuilder { + pub fn build( + self, + cx: ProxyTableBuildContext<'_, A, TB>, + ) -> Result, ProxyTableBuildError> + where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, + AS: AddressString
, + TB: TracerBuilder, + T: Tracer
+ Sync + Send + 'static, + { + let mut built = vec![]; + for entry in self.entries { + let e = entry.build(cx.clone())?; + built.push(e); + } + Ok(ProxyTable::new(built)) + } +} + +#[derive(Debug, Clone)] +pub struct ProxyTable { + entries: Vec>, +} +impl ProxyTable +where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, +{ + const BLOCK_ACTION: ProxyAction = ProxyAction::Block; + + pub fn new(entries: Vec>) -> Self { + Self { entries } + } + + pub fn action(&self, addr: &InternetAddr) -> &ProxyAction { + self.entries + .iter() + .find(|&entry| entry.matcher().matches(addr)) + .map(|entry| entry.action()) + .unwrap_or(&Self::BLOCK_ACTION) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ProxyTableEntryBuilder { + matcher: SharableConfig, + action: ProxyActionBuilder, +} +impl ProxyTableEntryBuilder { + pub fn build( + self, + cx: ProxyTableBuildContext<'_, A, TB>, + ) -> Result, ProxyTableBuildError> + where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, + AS: AddressString
, + TB: TracerBuilder, + T: Tracer
+ Sync + Send + 'static, + { + let matcher = match self.matcher { + SharableConfig::SharingKey(k) => cx + .matcher + .get(&k) + .cloned() + .ok_or_else(|| ProxyTableBuildError::ProxyGroupKeyNotFound(k))?, + SharableConfig::Private(v) => v.build().map_err(ProxyTableBuildError::Matcher)?, + }; + let action = self.action.build(cx.proxy_group, cx.proxy_group_cx)?; + Ok(ProxyTableEntry::new(matcher, action)) + } +} + +#[derive(Debug, Clone)] +pub struct ProxyTableEntry { + matcher: Matcher, + action: ProxyAction, +} +impl ProxyTableEntry +where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, +{ + pub fn new(matcher: Matcher, action: ProxyAction) -> Self { + Self { matcher, action } + } + + pub fn matcher(&self) -> &Matcher { + &self.matcher + } + + pub fn action(&self) -> &ProxyAction { + &self.action + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(rename_all = "snake_case")] +pub enum ProxyActionTagBuilder { + Direct, + Block, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(untagged)] +pub enum ProxyActionBuilder { + Tagged(ProxyActionTagBuilder), + ProxyGroup(SharableConfig>), +} +impl ProxyActionBuilder { + pub fn build( + self, + proxy_group: &HashMap, ProxyGroup>, + proxy_group_cx: ProxyGroupBuildContext<'_, A, TB>, + ) -> Result, ProxyTableBuildError> + where + A: std::fmt::Debug + fmt::Display + Clone + Send + Sync + 'static, + AS: AddressString
, + TB: TracerBuilder, + T: Tracer
+ Sync + Send + 'static, + { + Ok(match self { + ProxyActionBuilder::Tagged(ProxyActionTagBuilder::Direct) => ProxyAction::Direct, + ProxyActionBuilder::Tagged(ProxyActionTagBuilder::Block) => ProxyAction::Block, + ProxyActionBuilder::ProxyGroup(p) => ProxyAction::ProxyGroup(Arc::new(match p { + SharableConfig::SharingKey(k) => proxy_group + .get(&k) + .cloned() + .ok_or_else(|| ProxyTableBuildError::ProxyGroupKeyNotFound(k))?, + SharableConfig::Private(p) => p.build(proxy_group_cx)?, + })), + }) + } +} + +#[derive(Debug, Clone)] +pub enum ProxyAction { + Direct, + Block, + ProxyGroup(Arc>), +} + +#[derive(Debug, Error)] +pub enum ProxyTableBuildError { + #[error("Proxy group key not found: `{0}`")] + ProxyGroupKeyNotFound(Arc), + #[error("Matcher: {0}")] + Matcher(#[source] regex::Error), + #[error("Chain config is invalid: {0}")] + ChainConfig(#[source] ProxyConfigBuildError), + #[error("{0}")] + ProxyGroup(#[from] ProxyGroupBuildError), +} diff --git a/common/src/stream/addr.rs b/common/src/stream/addr.rs index 715c914..1c38317 100644 --- a/common/src/stream/addr.rs +++ b/common/src/stream/addr.rs @@ -24,7 +24,6 @@ pub trait StreamType: pub struct StreamAddrBuilder { pub address: Arc, } - impl StreamAddrBuilder { pub fn build(self) -> Result, ParseInternetAddrError> { self.address.as_ref().parse() @@ -37,13 +36,11 @@ pub struct StreamAddr { pub address: InternetAddr, pub stream_type: ST, } - impl Display for StreamAddr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}://{}", self.stream_type, self.address) } } - impl> FromStr for StreamAddr { type Err = ParseInternetAddrError; diff --git a/common/src/stream/pool.rs b/common/src/stream/pool.rs index 31d1cd1..d461d63 100644 --- a/common/src/stream/pool.rs +++ b/common/src/stream/pool.rs @@ -11,18 +11,14 @@ use tokio_conn_pool::{ConnPool, ConnPoolEntry}; use crate::{ config::{Merge, SharableConfig}, header::heartbeat::send_noop, - proxy_table::ProxyConfig, - stream::{ - proxy_table::{StreamProxyConfigBuildError, StreamProxyConfigBuilder}, - IoAddr, - }, + proxy_table::{AddressString, ProxyConfig, ProxyConfigBuildError, ProxyConfigBuilder}, + stream::IoAddr, }; use super::{ - addr::{StreamAddr, StreamAddrStr, StreamType}, + addr::{StreamAddr, StreamType}, connect::StreamConnectorTable, context::StreamContext, - proxy_table::StreamProxyConfig, IoStream, }; @@ -30,26 +26,24 @@ const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] -#[serde(bound(deserialize = "SAS: Deserialize<'de>"))] -pub struct PoolBuilder( - #[serde(default)] pub Vec>>, -); -impl PoolBuilder { +#[serde(bound(deserialize = "AS: Deserialize<'de>"))] +pub struct PoolBuilder(#[serde(default)] pub Vec>>); +impl PoolBuilder { pub fn new() -> Self { Self(vec![]) } } -impl PoolBuilder +impl PoolBuilder where - SAS: StreamAddrStr, - ST: StreamType, + AS: AddressString
>, { pub fn build( self, connector_table: CT, - stream_proxy: &HashMap, StreamProxyConfig>, + proxy_server: &HashMap, ProxyConfig>>, ) -> Result, C>, PoolBuildError> where + ST: StreamType + Clone, C: std::fmt::Debug + IoStream, CT: StreamConnectorTable, { @@ -57,13 +51,11 @@ where .0 .into_iter() .map(|c| match c { - SharableConfig::SharingKey(k) => stream_proxy + SharableConfig::SharingKey(k) => proxy_server .get(&k) .cloned() - .ok_or_else(|| PoolBuildError::KeyNotFound(k)), - SharableConfig::Private(c) => c - .build::() - .map_err(PoolBuildError::StreamProxyConfigBuild), + .ok_or_else(|| PoolBuildError::ProxyServerKeyNotFound(k)), + SharableConfig::Private(c) => c.build().map_err(PoolBuildError::ProxyConfigBuild), }) .collect::, _>>()?; let entries = pool_entries_from_proxy_configs(c.into_iter(), connector_table.clone()); @@ -74,9 +66,9 @@ where #[derive(Debug, Error)] pub enum PoolBuildError { #[error("{0}")] - StreamProxyConfigBuild(#[from] StreamProxyConfigBuildError), - #[error("Key not found: {0}")] - KeyNotFound(Arc), + ProxyConfigBuild(#[from] ProxyConfigBuildError), + #[error("Proxy server key not found: {0}")] + ProxyServerKeyNotFound(Arc), } impl Default for PoolBuilder { fn default() -> Self { @@ -160,7 +152,7 @@ where send_noop( &mut conn, HEARTBEAT_INTERVAL, - &self.proxy_config.crypto.clone(), + &self.proxy_config.header_crypto.clone(), ) .await .ok()?; diff --git a/common/src/stream/proxy_table.rs b/common/src/stream/proxy_table.rs index ab2216e..10c94d6 100644 --- a/common/src/stream/proxy_table.rs +++ b/common/src/stream/proxy_table.rs @@ -1,82 +1,99 @@ -use std::{collections::HashMap, sync::Arc}; +// use std::{collections::HashMap, sync::Arc}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; +// use serde::{Deserialize, Serialize}; +// use thiserror::Error; -use crate::{ - config::SharableConfig, - proxy_table::{ProxyConfig, ProxyTable, ProxyTableGroup, WeightedProxyChain}, -}; +// use crate::{ +// config::SharableConfig, +// proxy_table::{ +// ProxyAction, ProxyConfig, ProxyGroup, ProxyTable, ProxyTableEntry, WeightedProxyChain, +// }, +// }; -use super::{addr::StreamAddrStr, StreamAddr}; +// use super::{addr::StreamAddrStr, StreamAddr}; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct StreamProxyConfigBuilder { - pub address: SAS, - pub header_key: tokio_chacha20::config::ConfigBuilder, -} +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(deny_unknown_fields)] +// pub struct StreamProxyConfigBuilder { +// pub address: SAS, +// pub header_key: tokio_chacha20::config::ConfigBuilder, +// pub payload_key: Option, +// } -impl StreamProxyConfigBuilder { - pub fn build(self) -> Result, StreamProxyConfigBuildError> - where - SAS: StreamAddrStr, - { - let crypto = self.header_key.build()?; - let address = self.address.into_inner(); - Ok(ProxyConfig { address, crypto }) - } -} +// impl StreamProxyConfigBuilder { +// pub fn build(self) -> Result, StreamProxyConfigBuildError> +// where +// SAS: StreamAddrStr, +// { +// let header_crypto = self.header_key.build()?; +// let payload_crypto = self.payload_key.map(|p| p.build()).transpose()?; +// let address = self.address.into_inner(); +// Ok(ProxyConfig { +// address, +// header_crypto, +// payload_crypto, +// }) +// } +// } -#[derive(Debug, Error)] -pub enum StreamProxyConfigBuildError { - #[error("{0}")] - Crypto(#[from] tokio_chacha20::config::ConfigBuildError), - #[error("Key not found: {0}")] - KeyNotFound(Arc), -} +// #[derive(Debug, Error)] +// pub enum 
StreamProxyConfigBuildError { +// #[error("{0}")] +// Crypto(#[from] tokio_chacha20::config::ConfigBuildError), +// #[error("Key not found: {0}")] +// KeyNotFound(Arc), +// #[error("Multiple payload keys")] +// MultiplePayloadKeys, +// } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct StreamWeightedProxyChainBuilder { - pub weight: usize, - pub chain: Vec>>, - pub payload_key: Option, -} +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(deny_unknown_fields)] +// pub struct StreamWeightedProxyChainBuilder { +// pub weight: usize, +// pub chain: Vec>>, +// } -impl StreamWeightedProxyChainBuilder { - pub fn build( - self, - stream_proxy: &HashMap, StreamProxyConfig>, - ) -> Result, StreamProxyConfigBuildError> - where - SAS: StreamAddrStr, - { - let payload_crypto = match self.payload_key { - Some(c) => Some(c.build()?), - None => None, - }; - let chain = self - .chain - .into_iter() - .map(|c| match c { - SharableConfig::SharingKey(k) => stream_proxy - .get(&k) - .cloned() - .ok_or_else(|| StreamProxyConfigBuildError::KeyNotFound(k)), - SharableConfig::Private(c) => c.build(), - }) - .collect::>()?; - Ok(WeightedProxyChain { - weight: self.weight, - chain, - payload_crypto, - }) - } -} +// impl StreamWeightedProxyChainBuilder { +// pub fn build( +// self, +// proxy_server: &HashMap, StreamProxyConfig>, +// ) -> Result, StreamProxyConfigBuildError> +// where +// SAS: StreamAddrStr, +// { +// let chain = self +// .chain +// .into_iter() +// .map(|c| match c { +// SharableConfig::SharingKey(k) => proxy_server +// .get(&k) +// .cloned() +// .ok_or_else(|| StreamProxyConfigBuildError::KeyNotFound(k)), +// SharableConfig::Private(c) => c.build(), +// }) +// .collect::, _>>()?; +// let mut payload_crypto = None; +// for proxy_config in chain.iter() { +// let Some(p) = &proxy_config.payload_crypto else { +// continue; +// }; +// if payload_crypto.is_some() { +// return Err(StreamProxyConfigBuildError::MultiplePayloadKeys); +// } +// payload_crypto = Some(p.clone()); +// } +// Ok(WeightedProxyChain { +// weight: self.weight, +// chain, +// payload_crypto, +// }) +// } +// } -pub type StreamProxyConfig = ProxyConfig>; -pub type StreamProxyChain = [StreamProxyConfig]; -pub type StreamWeightedProxyChain = WeightedProxyChain>; -pub type StreamProxyTable = ProxyTable>; -pub type StreamProxyTableGroup = ProxyTableGroup>; +// pub type StreamProxyConfig = ProxyConfig>; +// pub type StreamProxyChain = [StreamProxyConfig]; +// pub type StreamWeightedProxyChain = WeightedProxyChain>; +// pub type StreamProxyTable = ProxyTable>; +// pub type StreamProxyTableEntry = ProxyTableEntry>; +// pub type StreamProxyTableEntryAction = ProxyAction>; +// pub type StreamProxyGroup = ProxyGroup>; diff --git a/common/src/udp/proxy_table.rs b/common/src/udp/proxy_table.rs index 27013c4..c188a99 100644 --- a/common/src/udp/proxy_table.rs +++ b/common/src/udp/proxy_table.rs @@ -1,75 +1,18 @@ -use std::{collections::HashMap, sync::Arc}; - -use serde::{Deserialize, Serialize}; -use thiserror::Error; - use crate::{ addr::{InternetAddr, InternetAddrStr}, - config::SharableConfig, - proxy_table::{ProxyConfig, ProxyTable, ProxyTableGroup, WeightedProxyChain}, + proxy_table::{ + ProxyAction, ProxyConfig, ProxyConfigBuilder, ProxyGroup, ProxyGroupBuilder, ProxyTable, + ProxyTableBuilder, ProxyTableEntry, WeightedProxyChain, + }, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct UdpProxyConfigBuilder { - pub address: InternetAddrStr, - 
pub header_key: tokio_chacha20::config::ConfigBuilder, -} - -impl UdpProxyConfigBuilder { - pub fn build(self) -> Result { - let crypto = self.header_key.build()?; - let address = self.address.0; - Ok(ProxyConfig { address, crypto }) - } -} - -#[derive(Debug, Error)] -pub enum UdpProxyConfigBuildError { - #[error("{0}")] - Crypto(#[from] tokio_chacha20::config::ConfigBuildError), - #[error("Key not found: {0}")] - KeyNotFound(Arc), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct UdpWeightedProxyChainBuilder { - pub weight: usize, - pub chain: Vec>, - pub payload_key: Option, -} - -impl UdpWeightedProxyChainBuilder { - pub fn build( - self, - udp_proxy: &HashMap, UdpProxyConfig>, - ) -> Result { - let payload_crypto = match self.payload_key { - Some(c) => Some(c.build()?), - None => None, - }; - let chain = self - .chain - .into_iter() - .map(|c| match c { - SharableConfig::SharingKey(k) => udp_proxy - .get(&k) - .cloned() - .ok_or_else(|| UdpProxyConfigBuildError::KeyNotFound(k)), - SharableConfig::Private(c) => c.build(), - }) - .collect::>()?; - Ok(WeightedProxyChain { - weight: self.weight, - chain, - payload_crypto, - }) - } -} - +pub type UdpProxyConfigBuilder = ProxyConfigBuilder; pub type UdpProxyConfig = ProxyConfig; pub type UdpProxyChain = [UdpProxyConfig]; pub type UdpWeightedProxyChain = WeightedProxyChain; pub type UdpProxyTable = ProxyTable; -pub type UdpProxyTableGroup = ProxyTableGroup; +pub type UdpProxyTableEntry = ProxyTableEntry; +pub type UdpProxyTableEntryAction = ProxyAction; +pub type UdpProxyGroup = ProxyGroup; +pub type UdpProxyTableBuilder = ProxyTableBuilder; +pub type UdpProxyGroupBuilder = ProxyGroupBuilder; diff --git a/protocol/src/stream/addr.rs b/protocol/src/stream/addr.rs index f74d5f5..bc9417b 100644 --- a/protocol/src/stream/addr.rs +++ b/protocol/src/stream/addr.rs @@ -4,7 +4,8 @@ use serde::{de::Visitor, Deserialize, Serialize}; use common::{ addr::ParseInternetAddrError, - stream::addr::{StreamAddr, StreamAddrStr, StreamType}, + proxy_table::AddressString, + stream::addr::{StreamAddr, StreamType}, }; pub type ConcreteStreamAddr = StreamAddr; @@ -41,12 +42,10 @@ impl FromStr for ConcreteStreamType { #[derive(Debug, Clone)] pub struct ConcreteStreamAddrStr(pub ConcreteStreamAddr); -impl StreamAddrStr for ConcreteStreamAddrStr { - type StreamType = ConcreteStreamType; - fn inner(&self) -> &StreamAddr { - &self.0 - } - fn into_inner(self) -> StreamAddr { +impl AddressString for ConcreteStreamAddrStr { + type Address = ConcreteStreamAddr; + + fn into_address(self) -> Self::Address { self.0 } } diff --git a/protocol/src/stream/mod.rs b/protocol/src/stream/mod.rs index 44cf070..76da2b6 100644 --- a/protocol/src/stream/mod.rs +++ b/protocol/src/stream/mod.rs @@ -3,4 +3,5 @@ pub mod connect; pub mod connection; pub mod context; pub mod pool; +pub mod proxy_table; pub mod streams; diff --git a/protocol/src/stream/proxy_table.rs b/protocol/src/stream/proxy_table.rs new file mode 100644 index 0000000..d5be96a --- /dev/null +++ b/protocol/src/stream/proxy_table.rs @@ -0,0 +1,17 @@ +use common::proxy_table::{ + ProxyAction, ProxyConfig, ProxyConfigBuilder, ProxyGroup, ProxyGroupBuilder, ProxyTable, + ProxyTableBuilder, ProxyTableEntry, WeightedProxyChain, +}; + +use super::addr::{ConcreteStreamAddr, ConcreteStreamAddrStr}; + +pub type StreamProxyConfigBuilder = ProxyConfigBuilder; +pub type StreamProxyConfig = ProxyConfig; +pub type StreamProxyChain = [StreamProxyConfig]; +pub type StreamWeightedProxyChain = 
WeightedProxyChain; +pub type StreamProxyTable = ProxyTable; +pub type StreamProxyTableEntry = ProxyTableEntry; +pub type StreamProxyTableEntryAction = ProxyAction; +pub type StreamProxyGroup = ProxyGroup; +pub type StreamProxyTableBuilder = ProxyTableBuilder; +pub type StreamProxyGroupBuilder = ProxyGroupBuilder; diff --git a/proxy_client/src/stream.rs b/proxy_client/src/stream.rs index 008244b..a974166 100644 --- a/proxy_client/src/stream.rs +++ b/proxy_client/src/stream.rs @@ -7,14 +7,12 @@ use common::{ heartbeat::{self, HeartbeatError}, route::{RouteError, RouteResponse}, }, - proxy_table::{convert_proxies_to_header_crypto_pairs, Tracer}, - stream::{pool::connect_with_pool, proxy_table::StreamProxyChain}, + proxy_table::{convert_proxies_to_header_crypto_pairs, ProxyChain, Tracer, TracerBuilder}, + stream::pool::connect_with_pool, }; use metrics::counter; use protocol::stream::{ - addr::{ConcreteStreamAddr, ConcreteStreamType}, - connection::ConnAndAddr, - context::ConcreteStreamContext, + addr::ConcreteStreamAddr, connection::ConnAndAddr, context::ConcreteStreamContext, pool::ConcreteConnectError, }; use thiserror::Error; @@ -24,7 +22,7 @@ const IO_TIMEOUT: Duration = Duration::from_secs(60); #[instrument(skip(proxies, stream_context))] pub async fn establish( - proxies: &StreamProxyChain, + proxies: &ProxyChain, destination: ConcreteStreamAddr, stream_context: &ConcreteStreamContext, ) -> Result { @@ -118,6 +116,23 @@ pub enum StreamEstablishError { }, } +#[derive(Debug, Clone)] +pub struct StreamTracerBuilder { + stream_context: ConcreteStreamContext, +} +impl StreamTracerBuilder { + pub fn new(stream_context: ConcreteStreamContext) -> Self { + Self { stream_context } + } +} +impl TracerBuilder for StreamTracerBuilder { + type Tracer = StreamTracer; + + fn build(&self) -> Self::Tracer { + StreamTracer::new(self.stream_context.clone()) + } +} + #[derive(Debug, Clone)] pub struct StreamTracer { stream_context: ConcreteStreamContext, @@ -132,7 +147,7 @@ impl Tracer for StreamTracer { async fn trace_rtt( &self, - chain: &StreamProxyChain, + chain: &ProxyChain, ) -> Result { trace_rtt(chain, &self.stream_context) .await @@ -141,7 +156,7 @@ impl Tracer for StreamTracer { } pub async fn trace_rtt( - proxies: &StreamProxyChain, + proxies: &ProxyChain, stream_context: &ConcreteStreamContext, ) -> Result { if proxies.is_empty() { diff --git a/proxy_client/src/udp.rs b/proxy_client/src/udp.rs index 404f76f..267ae28 100644 --- a/proxy_client/src/udp.rs +++ b/proxy_client/src/udp.rs @@ -12,7 +12,7 @@ use common::{ codec::{read_header, write_header, CodecError}, route::{RouteError, RouteResponse}, }, - proxy_table::{convert_proxies_to_header_crypto_pairs, Tracer}, + proxy_table::{convert_proxies_to_header_crypto_pairs, Tracer, TracerBuilder}, udp::{ io_copy::{UdpRecv, UdpSend}, proxy_table::UdpProxyChain, @@ -259,7 +259,8 @@ impl UdpProxyClientReadHalf { // Decode and check headers for node in self.proxies.iter() { trace!(?node.address, "Reading response"); - let mut crypto_cursor = tokio_chacha20::cursor::DecryptCursor::new(*node.crypto.key()); + let mut crypto_cursor = + tokio_chacha20::cursor::DecryptCursor::new(*node.header_crypto.key()); let resp: RouteResponse = read_header(&mut reader, &mut crypto_cursor)?; if let Err(err) = resp.result { warn!(?err, %node.address, "Upstream responded with an error"); @@ -296,21 +297,37 @@ pub enum RecvError { Response { err: RouteError, addr: InternetAddr }, } +pub struct UdpTracerBuilder {} +impl UdpTracerBuilder { + pub fn new() -> Self { + Self {} + } 
+} +impl Default for UdpTracerBuilder { + fn default() -> Self { + Self::new() + } +} +impl TracerBuilder for UdpTracerBuilder { + type Tracer = UdpTracer; + + fn build(&self) -> Self::Tracer { + UdpTracer::new() + } +} + #[derive(Debug, Clone)] pub struct UdpTracer {} - impl UdpTracer { pub fn new() -> Self { Self {} } } - impl Default for UdpTracer { fn default() -> Self { Self::new() } } - impl Tracer for UdpTracer { type Address = InternetAddr; @@ -357,7 +374,8 @@ pub async fn trace_rtt(proxies: &UdpProxyChain) -> Result let mut reader = io::Cursor::new(&buf[..n]); for node in proxies.iter() { trace!(?node.address, "Reading response"); - let mut crypto_cursor = tokio_chacha20::cursor::DecryptCursor::new(*node.crypto.key()); + let mut crypto_cursor = + tokio_chacha20::cursor::DecryptCursor::new(*node.header_crypto.key()); let resp: RouteResponse = read_header(&mut reader, &mut crypto_cursor)?; if let Err(err) = resp.result { warn!(?err, %node.address, "Upstream responded with an error"); diff --git a/proxy_server/src/lib.rs b/proxy_server/src/lib.rs index 8968e4d..57ce1ed 100644 --- a/proxy_server/src/lib.rs +++ b/proxy_server/src/lib.rs @@ -5,7 +5,7 @@ use protocol::context::ConcreteContext; use serde::Deserialize; use stream::{ kcp::KcpProxyServerConfig, mptcp::MptcpProxyServerConfig, tcp::TcpProxyServerConfig, - StreamProxy, + StreamProxyServer, }; use thiserror::Error; use udp::{UdpProxy, UdpProxyServerBuilder, UdpProxyServerConfig}; @@ -112,10 +112,10 @@ impl Merge for ProxyServerConfig { } pub struct ProxyServerLoader { - tcp_server: loading::Loader, + tcp_server: loading::Loader, udp_server: loading::Loader, - kcp_server: loading::Loader, - mptcp_server: loading::Loader, + kcp_server: loading::Loader, + mptcp_server: loading::Loader, } impl ProxyServerLoader { diff --git a/proxy_server/src/stream/kcp.rs b/proxy_server/src/stream/kcp.rs index 22bbff4..ad1311d 100644 --- a/proxy_server/src/stream/kcp.rs +++ b/proxy_server/src/stream/kcp.rs @@ -6,19 +6,23 @@ use protocol::stream::{ streams::kcp::{fast_kcp_config, KcpServer}, }; use serde::Deserialize; +use thiserror::Error; use tokio::net::ToSocketAddrs; use tokio_kcp::KcpListener; use crate::ListenerBindError; -use super::{StreamProxy, StreamProxyBuilder, StreamProxyConfig, StreamProxyServerBuildError}; +use super::{ + StreamProxyServer, StreamProxyServerBuildError, StreamProxyServerBuilder, + StreamProxyServerConfig, +}; #[derive(Debug, Clone, Deserialize)] #[serde(deny_unknown_fields)] pub struct KcpProxyServerConfig { pub listen_addr: Arc, #[serde(flatten)] - pub inner: StreamProxyConfig, + pub inner: StreamProxyServerConfig, } impl KcpProxyServerConfig { @@ -34,13 +38,12 @@ impl KcpProxyServerConfig { #[derive(Debug, Clone)] pub struct KcpProxyServerBuilder { pub listen_addr: Arc, - pub inner: StreamProxyBuilder, + pub inner: StreamProxyServerBuilder, } - impl loading::Builder for KcpProxyServerBuilder { - type Hook = StreamProxy; + type Hook = StreamProxyServer; type Server = KcpServer; - type Err = StreamProxyServerBuildError; + type Err = KcpProxyServerBuildError; async fn build_server(self) -> Result { let listen_addr = self.listen_addr.clone(); @@ -59,10 +62,18 @@ impl loading::Builder for KcpProxyServerBuilder { } } +#[derive(Debug, Error)] +pub enum KcpProxyServerBuildError { + #[error("{0}")] + Hook(#[from] StreamProxyServerBuildError), + #[error("{0}")] + Server(#[from] ListenerBindError), +} + pub async fn build_kcp_proxy_server( listen_addr: impl ToSocketAddrs, - stream_proxy: StreamProxy, -) -> Result, 
ListenerBindError> { + stream_proxy: StreamProxyServer, +) -> Result, ListenerBindError> { let config = fast_kcp_config(); let listener = KcpListener::bind(config, listen_addr) .await diff --git a/proxy_server/src/stream/mod.rs b/proxy_server/src/stream/mod.rs index 42c7601..e4c0e4c 100644 --- a/proxy_server/src/stream/mod.rs +++ b/proxy_server/src/stream/mod.rs @@ -18,8 +18,6 @@ use serde::Deserialize; use thiserror::Error; use tracing::{error, info, instrument, warn}; -use crate::ListenerBindError; - pub mod kcp; pub mod mptcp; pub mod tcp; @@ -28,14 +26,14 @@ const IO_TIMEOUT: Duration = Duration::from_secs(60); #[derive(Debug, Clone, Deserialize)] #[serde(deny_unknown_fields)] -pub struct StreamProxyConfig { +pub struct StreamProxyServerConfig { pub header_key: tokio_chacha20::config::ConfigBuilder, pub payload_key: Option, } -impl StreamProxyConfig { - pub fn into_builder(self, stream_context: ConcreteStreamContext) -> StreamProxyBuilder { - StreamProxyBuilder { +impl StreamProxyServerConfig { + pub fn into_builder(self, stream_context: ConcreteStreamContext) -> StreamProxyServerBuilder { + StreamProxyServerBuilder { header_key: self.header_key, payload_key: self.payload_key, stream_context, @@ -44,23 +42,26 @@ impl StreamProxyConfig { } #[derive(Debug, Clone)] -pub struct StreamProxyBuilder { +pub struct StreamProxyServerBuilder { pub header_key: tokio_chacha20::config::ConfigBuilder, pub payload_key: Option, pub stream_context: ConcreteStreamContext, } -impl StreamProxyBuilder { - pub fn build(self) -> Result { +impl StreamProxyServerBuilder { + pub fn build(self) -> Result { let header_crypto = self .header_key .build() - .map_err(StreamProxyBuildError::HeaderCrypto)?; + .map_err(StreamProxyServerBuildError::HeaderCrypto)?; let payload_crypto = match self.payload_key { - Some(key) => Some(key.build().map_err(StreamProxyBuildError::PayloadCrypto)?), + Some(key) => Some( + key.build() + .map_err(StreamProxyServerBuildError::PayloadCrypto)?, + ), None => None, }; - Ok(StreamProxy::new( + Ok(StreamProxyServer::new( header_crypto, payload_crypto, self.stream_context, @@ -69,7 +70,7 @@ impl StreamProxyBuilder { } #[derive(Debug, Error)] -pub enum StreamProxyBuildError { +pub enum StreamProxyServerBuildError { #[error("HeaderCrypto: {0}")] HeaderCrypto(#[source] tokio_chacha20::config::ConfigBuildError), #[error("PayloadCrypto: {0}")] @@ -78,22 +79,14 @@ pub enum StreamProxyBuildError { StreamPool(#[from] ParseInternetAddrError), } -#[derive(Debug, Error)] -pub enum StreamProxyServerBuildError { - #[error("{0}")] - Hook(#[from] StreamProxyBuildError), - #[error("{0}")] - Server(#[from] ListenerBindError), -} - #[derive(Debug)] -pub struct StreamProxy { +pub struct StreamProxyServer { acceptor: StreamProxyAcceptor, payload_crypto: Option, stream_context: ConcreteStreamContext, } -impl StreamProxy { +impl StreamProxyServer { pub fn new( header_crypto: tokio_chacha20::config::Config, payload_crypto: Option, @@ -165,10 +158,8 @@ impl StreamProxy { // }); // } } - -impl loading::Hook for StreamProxy {} - -impl StreamServerHook for StreamProxy { +impl loading::Hook for StreamProxyServer {} +impl StreamServerHook for StreamProxyServer { #[instrument(skip(self))] async fn handle_stream(&self, stream: S) where @@ -192,7 +183,6 @@ pub struct StreamProxyAcceptor { crypto: tokio_chacha20::config::Config, stream_context: ConcreteStreamContext, } - impl StreamProxyAcceptor { pub fn new( crypto: tokio_chacha20::config::Config, diff --git a/proxy_server/src/stream/mptcp.rs 
b/proxy_server/src/stream/mptcp.rs index 9328fab..8bc486f 100644 --- a/proxy_server/src/stream/mptcp.rs +++ b/proxy_server/src/stream/mptcp.rs @@ -10,7 +10,10 @@ use tracing::error; use crate::ListenerBindError; -use super::{StreamProxy, StreamProxyBuildError, StreamProxyBuilder, StreamProxyConfig}; +use super::{ + StreamProxyServer, StreamProxyServerBuildError, StreamProxyServerBuilder, + StreamProxyServerConfig, +}; const MAX_SESSION_STREAMS: usize = 4; @@ -19,7 +22,7 @@ const MAX_SESSION_STREAMS: usize = 4; pub struct MptcpProxyServerConfig { pub listen_addr: Arc, #[serde(flatten)] - pub inner: StreamProxyConfig, + pub inner: StreamProxyServerConfig, } impl MptcpProxyServerConfig { @@ -35,11 +38,11 @@ impl MptcpProxyServerConfig { #[derive(Debug, Clone)] pub struct MptcpProxyServerBuilder { pub listen_addr: Arc, - pub inner: StreamProxyBuilder, + pub inner: StreamProxyServerBuilder, } impl loading::Builder for MptcpProxyServerBuilder { - type Hook = StreamProxy; + type Hook = StreamProxyServer; type Server = MptcpServer; type Err = MptcpProxyServerBuildError; @@ -63,15 +66,15 @@ impl loading::Builder for MptcpProxyServerBuilder { #[derive(Debug, Error)] pub enum MptcpProxyServerBuildError { #[error("{0}")] - Hook(#[from] StreamProxyBuildError), + Hook(#[from] StreamProxyServerBuildError), #[error("{0}")] Server(#[from] ListenerBindError), } pub async fn build_mptcp_proxy_server( listen_addr: impl ToSocketAddrs, - stream_proxy: StreamProxy, -) -> Result, ListenerBindError> { + stream_proxy: StreamProxyServer, +) -> Result, ListenerBindError> { let listener = MptcpListener::bind(listen_addr, NonZeroUsize::new(MAX_SESSION_STREAMS).unwrap()) .await diff --git a/proxy_server/src/stream/tcp.rs b/proxy_server/src/stream/tcp.rs index bef8235..04f993e 100644 --- a/proxy_server/src/stream/tcp.rs +++ b/proxy_server/src/stream/tcp.rs @@ -9,14 +9,17 @@ use tracing::error; use crate::ListenerBindError; -use super::{StreamProxy, StreamProxyBuildError, StreamProxyBuilder, StreamProxyConfig}; +use super::{ + StreamProxyServer, StreamProxyServerBuildError, StreamProxyServerBuilder, + StreamProxyServerConfig, +}; #[derive(Debug, Clone, Deserialize)] #[serde(deny_unknown_fields)] pub struct TcpProxyServerConfig { pub listen_addr: Arc, #[serde(flatten)] - pub inner: StreamProxyConfig, + pub inner: StreamProxyServerConfig, } impl TcpProxyServerConfig { @@ -32,11 +35,11 @@ impl TcpProxyServerConfig { #[derive(Debug, Clone)] pub struct TcpProxyServerBuilder { pub listen_addr: Arc, - pub inner: StreamProxyBuilder, + pub inner: StreamProxyServerBuilder, } impl loading::Builder for TcpProxyServerBuilder { - type Hook = StreamProxy; + type Hook = StreamProxyServer; type Server = TcpServer; type Err = TcpProxyServerBuildError; @@ -60,15 +63,15 @@ impl loading::Builder for TcpProxyServerBuilder { #[derive(Debug, Error)] pub enum TcpProxyServerBuildError { #[error("{0}")] - Hook(#[from] StreamProxyBuildError), + Hook(#[from] StreamProxyServerBuildError), #[error("{0}")] Server(#[from] ListenerBindError), } pub async fn build_tcp_proxy_server( listen_addr: impl ToSocketAddrs, - stream_proxy: StreamProxy, -) -> Result, ListenerBindError> { + stream_proxy: StreamProxyServer, +) -> Result, ListenerBindError> { let listener = TcpListener::bind(listen_addr) .await .map_err(ListenerBindError)?; @@ -101,7 +104,7 @@ mod tests { // Start proxy server let proxy_addr = { - let proxy = StreamProxy::new( + let proxy = StreamProxyServer::new( crypto.clone(), None, ConcreteStreamContext { diff --git a/server/config.toml 
b/server/config.toml index be31c97..0c2892e 100644 --- a/server/config.toml +++ b/server/config.toml @@ -1,68 +1,62 @@ -[global] -stream_pool = [ - { address = "tcp://proxy.example.org:80", header_key = "base64" }, - { address = "mptcp://proxy.example.org:81", header_key = "base64" }, +[stream.proxy_server] +"tcp1" = { address = "tcp://proxy.example.org:80", header_key = "base64", payload_key = "base64" } +"kcp1" = { address = "kcp://proxy.example.org:80", header_key = "base64", payload_key = "base64" } +"mptcp1" = { address = "mptcp://proxy.example.org:81", header_key = "base64", payload_key = "base64" } + +[stream] +pool = ["tcp1", "mptcp1"] + +[udp.proxy_server] +"udp1" = { address = "proxy.example.org:80", header_key = "base64", payload_key = "base64" } + +[access_server.matcher] +"localhost" = [ + { addr = "127.0.0.1" }, # IPv4 + { addr = "::1" }, # IPv6 + { addr = "localhost" }, # domain name regex ] -[access_server.stream_proxy_tables] -"default" = { groups = [ - { matcher = { }, chains = [ - { weight = 1, chain = [ - { address = "tcp://proxy.example.org:80", header_key = "base64" }, - ], payload_key = "base64" }, - { weight = 1, chain = [ - { address = "kcp://proxy.example.org:80", header_key = "base64" }, - ], payload_key = "base64" }, - { weight = 1, chain = [ - { address = "mptcp://proxy.example.org:81", header_key = "base64" }, - ], payload_key = "base64" }, - ], trace_rtt = false, active_chains = 3 }, -] } -[access_server.udp_proxy_tables] -"default" = { groups = [ - { matcher = { }, chains = [ - { weight = 1, chain = [ - { address = "proxy.example.org:80", header_key = "base64" }, - ], payload_key = "base64" }, - ], trace_rtt = false, active_chains = 1 }, -] } -[access_server.filters] +[access_server.stream.proxy_group] +"default" = { chains = [ + { weight = 1, chain = ["tcp1"] }, + { weight = 1, chain = ["kcp1"] }, + { weight = 1, chain = ["mptcp1"] }, +], trace_rtt = false, active_chains = 3 } + +[access_server.stream.proxy_table] "default" = [ - "localhost", - # { matcher = { port = 443 }, action = "proxy" }, + { matcher = "localhost", action = "direct" }, # Forward local traffic directly + # <- At this point, no more local traffic # { matcher = { }, action = "block" }, # Block all + { matcher = { }, action = "default" }, # Forward remaining traffic to the "default" proxy group above ] -"localhost" = [ - { matcher = [ - { addr = "127.0.0.1" }, # IPv4 - { addr = "::1" }, # IPv6 - { addr = "localhost" }, # domain name regex - ], action = "direct" }, -] + +[access_server.udp.proxy_group] +"default" = { chains = [ + { weight = 1, chain = ["udp1"] }, +], trace_rtt = false, active_chains = 1 } [[access_server.tcp_server]] listen_addr = "0.0.0.0:80" destination = "tcp://www.example.org:80" -proxy_table = "default" # Or, `proxy_table = { chains = ..., trace_rtt = ..., ... }` +proxy_group = "default" # Or, `proxy_group = { chains = ..., trace_rtt = ..., ... }` # speed_limit = 1024.0 # 1.0 KiB/s [[access_server.udp_server]] listen_addr = "0.0.0.0:80" destination = "www.example.org:80" -proxy_table = "default" # Or, `proxy_table = { chains = ..., trace_rtt = ..., ... }` +proxy_group = "default" # Or, `proxy_group = { chains = ..., trace_rtt = ..., ... }` # speed_limit = 1024.0 # 1.0 KiB/s [[access_server.http_server]] listen_addr = "0.0.0.0:80" -proxy_table = "default" # Or, `proxy_table = { chains = ..., trace_rtt = ..., ... 
}` -filter = "default" +proxy_table = "default" # Or, `proxy_table = [...]` # speed_limit = 1024.0 # 1.0 KiB/s [[access_server.socks5_tcp_server]] listen_addr = "0.0.0.0:80" udp_server_addr = "0.0.0.0:80" -proxy_table = "default" # Or, `proxy_table = { chains = ..., trace_rtt = ..., ... }` -filter = "default" +proxy_table = "default" # Or, `proxy_table = [...]` # speed_limit = 1024.0 # 1.0 KiB/s # users = [ # { username = "", password = "" }, @@ -70,7 +64,7 @@ filter = "default" [[access_server.socks5_udp_server]] listen_addr = "0.0.0.0:80" -proxy_table = "default" # Or, `proxy_table = { chains = ..., trace_rtt = ..., ... }` +proxy_group = "default" # Or, `proxy_group = { chains = ..., trace_rtt = ..., ... }` # speed_limit = 1024.0 # 1.0 KiB/s [[proxy_server.tcp_server]] diff --git a/server/src/lib.rs b/server/src/lib.rs index b3edd2b..6407c3f 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,11 +1,11 @@ -use std::{collections::HashMap, convert::Infallible, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use access_server::{AccessServerConfig, AccessServerLoader}; use common::{ config::{merge_map, Merge}, context::Context, error::{AnyError, AnyResult}, - stream::{proxy_table::StreamProxyConfigBuilder, session_table::StreamSessionTable}, + stream::session_table::StreamSessionTable, udp::{ context::UdpContext, proxy_table::UdpProxyConfigBuilder, session_table::UdpSessionTable, }, @@ -14,10 +14,11 @@ use config::ConfigReader; use protocol::{ context::ConcreteContext, stream::{ - addr::{ConcreteStreamAddrStr, ConcreteStreamType}, + addr::ConcreteStreamType, connect::ConcreteStreamConnectorTable, context::ConcreteStreamContext, pool::{ConcreteConnPool, ConcretePoolBuilder}, + proxy_table::StreamProxyConfigBuilder, }, }; use proxy_server::{ProxyServerConfig, ProxyServerLoader}; @@ -155,23 +156,23 @@ pub async fn load_and_clean( cancellation: CancellationToken, context: ConcreteContext, ) -> AnyResult { - let mut stream_proxy = HashMap::new(); - for (k, v) in config.stream_proxy { + let mut stream_proxy_server = HashMap::new(); + for (k, v) in config.stream.proxy_server { let v = v.build()?; - stream_proxy.insert(k, v); + stream_proxy_server.insert(k, v); } let mut udp_proxy = HashMap::new(); - for (k, v) in config.udp_proxy { + for (k, v) in config.udp.proxy_server { let v = v.build()?; udp_proxy.insert(k, v); } context.stream.pool.replaced_by( config - .global - .stream_pool - .build(context.stream.connector_table.clone(), &stream_proxy)?, + .stream + .pool + .build(context.stream.connector_table.clone(), &stream_proxy_server)?, ); config @@ -181,7 +182,7 @@ pub async fn load_and_clean( &mut server_loader.access_server, cancellation, context.clone(), - &stream_proxy, + &stream_proxy_server, &udp_proxy, ) .await?; @@ -198,54 +199,71 @@ pub async fn load_and_clean( #[derive(Debug, Deserialize, Default)] #[serde(deny_unknown_fields)] -pub struct ServerConfig { - #[serde(default)] - pub access_server: AccessServerConfig, - #[serde(default)] - pub proxy_server: ProxyServerConfig, +pub struct StreamConfig { #[serde(default)] - pub global: Global, + pool: ConcretePoolBuilder, #[serde(default)] - pub stream_proxy: HashMap, StreamProxyConfigBuilder>, + proxy_server: HashMap, StreamProxyConfigBuilder>, +} +impl Merge for StreamConfig { + type Error = AnyError; + + fn merge(self, other: Self) -> Result + where + Self: Sized, + { + let pool = self.pool.merge(other.pool)?; + let proxy_server = merge_map(self.proxy_server, other.proxy_server)?; + Ok(Self { pool, proxy_server }) + } +} + 
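// A minimal sketch of how the `Merge` impls in this file are meant to be
// driven: per-file `ServerConfig` fragments fold into a single config, with
// `merge_map` joining the keyed `proxy_server` maps of the `stream` and `udp`
// sections. The `merge_all` helper is hypothetical and not part of this
// change; it only relies on the types defined in this file.
fn merge_all(configs: Vec<ServerConfig>) -> Result<ServerConfig, AnyError> {
    let mut merged = ServerConfig::default();
    for config in configs {
        // `Merge::merge` consumes both sides and returns the combined config.
        merged = merged.merge(config)?;
    }
    Ok(merged)
}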
+#[derive(Debug, Deserialize, Default)] +#[serde(deny_unknown_fields)] +pub struct UdpConfig { #[serde(default)] - pub udp_proxy: HashMap, UdpProxyConfigBuilder>, + proxy_server: HashMap, UdpProxyConfigBuilder>, } -impl Merge for ServerConfig { +impl Merge for UdpConfig { type Error = AnyError; fn merge(self, other: Self) -> Result where Self: Sized, { - let access_server = self.access_server.merge(other.access_server)?; - let proxy_server = self.proxy_server.merge(other.proxy_server)?; - let global = self.global.merge(other.global)?; - let stream_proxy = merge_map(self.stream_proxy, other.stream_proxy)?; - let udp_proxy = merge_map(self.udp_proxy, other.udp_proxy)?; - Ok(Self { - access_server, - proxy_server, - global, - stream_proxy, - udp_proxy, - }) + let proxy_server = merge_map(self.proxy_server, other.proxy_server)?; + Ok(Self { proxy_server }) } } -#[derive(Debug, Default, Deserialize)] +#[derive(Debug, Deserialize, Default)] #[serde(deny_unknown_fields)] -pub struct Global { +pub struct ServerConfig { + #[serde(default)] + pub access_server: AccessServerConfig, + #[serde(default)] + pub proxy_server: ProxyServerConfig, #[serde(default)] - pub stream_pool: ConcretePoolBuilder, + pub stream: StreamConfig, + #[serde(default)] + pub udp: UdpConfig, } -impl Merge for Global { - type Error = Infallible; +impl Merge for ServerConfig { + type Error = AnyError; fn merge(self, other: Self) -> Result where Self: Sized, { - let stream_pool = self.stream_pool.merge(other.stream_pool)?; - Ok(Self { stream_pool }) + let access_server = self.access_server.merge(other.access_server)?; + let proxy_server = self.proxy_server.merge(other.proxy_server)?; + let stream = self.stream.merge(other.stream)?; + let udp = self.udp.merge(other.udp)?; + Ok(Self { + access_server, + proxy_server, + stream, + udp, + }) } } diff --git a/tests/src/tcp.rs b/tests/src/tcp.rs index 0778f8c..38abc1a 100644 --- a/tests/src/tcp.rs +++ b/tests/src/tcp.rs @@ -2,22 +2,19 @@ mod tests { use std::{io, time::Duration}; - use common::{ - loading::Server, - proxy_table::ProxyConfig, - stream::{addr::StreamAddr, proxy_table::StreamProxyConfig}, - }; + use common::{loading::Server, proxy_table::ProxyConfig, stream::addr::StreamAddr}; use protocol::stream::{ addr::{ConcreteStreamAddr, ConcreteStreamType}, connect::ConcreteStreamConnectorTable, connection::ConnAndAddr, context::ConcreteStreamContext, pool::ConcreteConnPool, + proxy_table::StreamProxyConfig, }; use proxy_client::stream::{establish, trace_rtt}; use proxy_server::stream::{ kcp::build_kcp_proxy_server, mptcp::build_mptcp_proxy_server, tcp::build_tcp_proxy_server, - StreamProxy, + StreamProxyServer, }; use serial_test::serial; use swap::Swap; @@ -45,9 +42,9 @@ mod tests { join_set: &mut tokio::task::JoinSet<()>, addr: &str, ty: ConcreteStreamType, - ) -> StreamProxyConfig { + ) -> StreamProxyConfig { let crypto = create_random_crypto(); - let proxy = StreamProxy::new(crypto.clone(), None, stream_context()); + let proxy = StreamProxyServer::new(crypto.clone(), None, stream_context()); let proxy_addr = match ty { ConcreteStreamType::Tcp => { let server = build_tcp_proxy_server(addr, proxy).await.unwrap(); @@ -82,21 +79,22 @@ mod tests { address: proxy_addr.into(), stream_type: ty, }, - crypto, + header_crypto: crypto, + payload_crypto: None, } } async fn spawn_tcp_proxy( join_set: &mut tokio::task::JoinSet<()>, addr: &str, - ) -> StreamProxyConfig { + ) -> StreamProxyConfig { spawn_proxy(join_set, addr, ConcreteStreamType::Tcp).await } async fn spawn_kcp_proxy( join_set: 
&mut tokio::task::JoinSet<()>, addr: &str, - ) -> StreamProxyConfig { + ) -> StreamProxyConfig { spawn_proxy(join_set, addr, ConcreteStreamType::Kcp).await } diff --git a/tests/src/udp.rs b/tests/src/udp.rs index a879193..46d8591 100644 --- a/tests/src/udp.rs +++ b/tests/src/udp.rs @@ -38,7 +38,8 @@ mod tests { }); ProxyConfig { address: proxy_addr.into(), - crypto, + header_crypto: crypto, + payload_crypto: None, } }
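For orientation, the sketch below shows one way an access server can consume the rebuilt table types at runtime: `ProxyTable::action` resolves a destination address to `Direct`, `Block`, or a `ProxyGroup`, and `ProxyGroup::choose_chain` then picks one weighted chain, biased by traced RTT and loss when `trace_rtt` is enabled. The `Decision` enum and `dispatch` helper are illustrative only and are not part of this diff; they assume the UDP aliases from `common/src/udp/proxy_table.rs`.

use common::{
    addr::InternetAddr,
    proxy_table::{ProxyAction, ProxyChain},
    udp::proxy_table::UdpProxyTable,
};

// What the caller should do with traffic bound for `dst`.
enum Decision<'a> {
    Direct,
    Block,
    // The selected chain of proxy hops to establish through.
    Chain(&'a ProxyChain<InternetAddr>),
}

fn dispatch<'a>(table: &'a UdpProxyTable, dst: &InternetAddr) -> Decision<'a> {
    // `action` returns the action of the first entry whose matcher matches
    // `dst`, falling back to `Block` when no entry matches.
    match table.action(dst) {
        ProxyAction::Direct => Decision::Direct,
        ProxyAction::Block => Decision::Block,
        ProxyAction::ProxyGroup(group) => {
            // `choose_chain` samples one weighted chain from the group.
            Decision::Chain(&*group.choose_chain().chain)
        }
    }
}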