From 319a85ce3f064948e5987ae46d5851c8d82f766d Mon Sep 17 00:00:00 2001 From: Banyc <36535895+Banyc@users.noreply.github.com> Date: Wed, 21 Feb 2024 02:16:15 +0800 Subject: [PATCH] feat: matchable proxy table --- Cargo.lock | 1 + Cargo.toml | 1 + access_server/Cargo.toml | 1 + access_server/src/socks5/server/tcp.rs | 18 ++++++-- access_server/src/socks5/server/udp.rs | 7 +++- access_server/src/stream/proxy_table.rs | 41 ++++++++++++++++--- .../src/stream/streams/http_tunnel/mod.rs | 13 +++++- access_server/src/stream/streams/tcp.rs | 7 +++- access_server/src/udp/mod.rs | 7 +++- access_server/src/udp/proxy_table.rs | 41 ++++++++++++++++--- common/Cargo.toml | 2 +- common/src/filter.rs | 8 ++++ common/src/proxy_table.rs | 35 ++++++++++++++-- common/src/stream/proxy_table.rs | 3 +- common/src/udp/proxy_table.rs | 3 +- server/config.toml | 36 ++++++++-------- 16 files changed, 181 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dc2266c..2296f40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16,6 +16,7 @@ dependencies = [ "pin-project-lite", "protocol", "proxy_client", + "regex", "serde", "table_log", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 3ffa653..4d66eaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ monitor_table = { git = "https://github.com/Banyc/monitor_table.git", rev = "f9c mptcp = { git = "https://github.com/Banyc/mptcp.git", rev = "8c0b8ee35bc7d570de272f5da7af6742b460e23a" } once_cell = "1" openssl = "0.10" +regex = "1" serde = "1" strict-num = "0.2" swap = { git = "https://github.com/Banyc/swap.git", rev = "d10a8b5b10503fa6ebac523cfcaa4d62135a665f" } diff --git a/access_server/Cargo.toml b/access_server/Cargo.toml index ae71588..a85136c 100644 --- a/access_server/Cargo.toml +++ b/access_server/Cargo.toml @@ -15,6 +15,7 @@ monitor_table = { workspace = true } pin-project-lite = "0.2" protocol = { path = "../protocol" } proxy_client = { path = "../proxy_client" } +regex = { workspace = true } serde = { workspace = true, features = ["derive"] } table_log = { workspace = true } thiserror = { workspace = true } diff --git a/access_server/src/socks5/server/tcp.rs b/access_server/src/socks5/server/tcp.rs index 00d62ad..e849c76 100644 --- a/access_server/src/socks5/server/tcp.rs +++ b/access_server/src/socks5/server/tcp.rs @@ -532,8 +532,12 @@ impl Socks5ServerTcpAccess { async fn establish_proxy_chain( &self, destination: InternetAddr, - ) -> Result<(ConnAndAddr, Option), StreamEstablishError> { - let proxy_chain = self.proxy_table.choose_chain(); + ) -> Result<(ConnAndAddr, Option), EstablishProxyChainError> + { + let Some(proxy_table_group) = self.proxy_table.group(&destination) else { + return Err(EstablishProxyChainError::NoProxy); + }; + let proxy_chain = proxy_table_group.choose_chain(); let res = proxy_client::stream::establish( &proxy_chain.chain, StreamAddr { @@ -547,6 +551,14 @@ impl Socks5ServerTcpAccess { } } +#[derive(Debug, Error)] +pub enum EstablishProxyChainError { + #[error("No proxy")] + NoProxy, + #[error("{0}")] + StreamEstablish(#[from] StreamEstablishError), +} + pub enum EstablishResult { Blocked { destination: InternetAddr, @@ -604,7 +616,7 @@ pub enum EstablishError { #[error("IO error: {0}")] Io(#[from] io::Error), #[error("Failed to establish proxy chain: {0}")] - EstablishProxyChain(#[from] StreamEstablishError), + EstablishProxyChain(#[from] EstablishProxyChainError), #[error("Command BIND not supported")] CmdBindNotSupported, #[error("No UDP server available")] diff --git a/access_server/src/socks5/server/udp.rs b/access_server/src/socks5/server/udp.rs index d966fa3..dd8ac1c 100644 --- a/access_server/src/socks5/server/udp.rs +++ b/access_server/src/socks5/server/udp.rs @@ -127,7 +127,10 @@ impl Socks5ServerUdpAccess { downstream_writer: UdpDownstreamWriter, ) -> Result<(), AccessProxyError> { // Connect to upstream - let proxy_chain = self.proxy_table.choose_chain(); + let Some(proxy_table_group) = self.proxy_table.group(&flow.flow().upstream.0) else { + return Err(AccessProxyError::NoProxy); + }; + let proxy_chain = proxy_table_group.choose_chain(); let upstream = UdpProxyClient::establish(proxy_chain.chain.clone(), flow.flow().upstream.0.clone()) .await?; @@ -174,6 +177,8 @@ impl Socks5ServerUdpAccess { #[derive(Debug, Error)] pub enum AccessProxyError { + #[error("No proxy")] + NoProxy, #[error("Failed to establish proxy chain: {0}")] Establish(#[from] EstablishError), } diff --git a/access_server/src/stream/proxy_table.rs b/access_server/src/stream/proxy_table.rs index a7066c7..457326e 100644 --- a/access_server/src/stream/proxy_table.rs +++ b/access_server/src/stream/proxy_table.rs @@ -1,9 +1,10 @@ use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; use common::{ - proxy_table::{ProxyTable, ProxyTableError}, + filter::MatcherBuilder, + proxy_table::{ProxyTable, ProxyTableError, ProxyTableGroup}, stream::proxy_table::{ - StreamProxyConfig, StreamProxyConfigBuildError, StreamProxyTable, + StreamProxyConfig, StreamProxyConfigBuildError, StreamProxyTable, StreamProxyTableGroup, StreamWeightedProxyChainBuilder, }, }; @@ -19,18 +20,43 @@ use tokio_util::sync::CancellationToken; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct StreamProxyTableBuilder { + pub groups: Vec, +} +impl StreamProxyTableBuilder { + pub fn build( + self, + stream_proxy: &HashMap, StreamProxyConfig>, + stream_context: &ConcreteStreamContext, + cancellation: CancellationToken, + ) -> Result, StreamProxyTableBuildError> { + let mut built = vec![]; + for group in self.groups { + let g = group.build(stream_proxy, stream_context, cancellation.clone())?; + built.push(g); + } + Ok(ProxyTable::new(built)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct StreamProxyTableGroupBuilder { + pub matcher: MatcherBuilder, pub chains: Vec>, pub trace_rtt: bool, pub active_chains: Option, } - -impl StreamProxyTableBuilder { +impl StreamProxyTableGroupBuilder { pub fn build( self, stream_proxy: &HashMap, StreamProxyConfig>, stream_context: &ConcreteStreamContext, cancellation: CancellationToken, - ) -> Result, StreamProxyTableBuildError> { + ) -> Result, StreamProxyTableBuildError> { + let matcher = self + .matcher + .build() + .map_err(StreamProxyTableBuildError::Matcher)?; let chains = self .chains .into_iter() @@ -41,7 +67,8 @@ impl StreamProxyTableBuilder { true => Some(StreamTracer::new(stream_context.clone())), false => None, }; - Ok(ProxyTable::new( + Ok(ProxyTableGroup::new( + matcher, chains, tracer, self.active_chains, @@ -52,6 +79,8 @@ impl StreamProxyTableBuilder { #[derive(Debug, Error)] pub enum StreamProxyTableBuildError { + #[error("Matcher: {0}")] + Matcher(#[source] regex::Error), #[error("Chain config is invalid: {0}")] ChainConfig(#[source] StreamProxyConfigBuildError), #[error("{0}")] diff --git a/access_server/src/stream/streams/http_tunnel/mod.rs b/access_server/src/stream/streams/http_tunnel/mod.rs index 9a3c255..67b8943 100644 --- a/access_server/src/stream/streams/http_tunnel/mod.rs +++ b/access_server/src/stream/streams/http_tunnel/mod.rs @@ -231,7 +231,11 @@ impl HttpAccess { } // Establish proxy chain - let proxy_chain = self.proxy_table.choose_chain(); + let Some(proxy_table_group) = self.proxy_table.group(&addr.address) else { + trace!(?addr, "No proxy {}", method); + return Ok(respond_with_rejection()); + }; + let proxy_chain = proxy_table_group.choose_chain(); let upstream = establish(&proxy_chain.chain, addr.clone(), &self.stream_context).await?; let session_guard = self.stream_context.session_table.as_ref().map(|s| { @@ -500,7 +504,10 @@ impl HttpConnect { address: address.clone(), stream_type: ConcreteStreamType::Tcp, }; - let proxy_chain = self.proxy_table.choose_chain(); + let Some(proxy_table_group) = self.proxy_table.group(&address) else { + return Err(HttpConnectError::NoProxy); + }; + let proxy_chain = proxy_table_group.choose_chain(); let upstream = establish( &proxy_chain.chain, destination.clone(), @@ -534,6 +541,8 @@ impl HttpConnect { #[derive(Debug, Error)] pub enum HttpConnectError { + #[error("No proxy")] + NoProxy, #[error("Failed to establish proxy chain")] EstablishProxyChain(#[from] StreamEstablishError), } diff --git a/access_server/src/stream/streams/tcp.rs b/access_server/src/stream/streams/tcp.rs index 5d44667..90ec0fa 100644 --- a/access_server/src/stream/streams/tcp.rs +++ b/access_server/src/stream/streams/tcp.rs @@ -134,7 +134,10 @@ impl TcpAccess { where S: IoStream + IoAddr, { - let proxy_chain = self.proxy_table.choose_chain(); + let Some(proxy_table_group) = self.proxy_table.group(&self.destination.address) else { + return Err(ProxyError::NoProxy); + }; + let proxy_chain = proxy_table_group.choose_chain(); let upstream = establish( &proxy_chain.chain, self.destination.clone(), @@ -168,6 +171,8 @@ impl TcpAccess { #[derive(Debug, Error)] pub enum ProxyError { + #[error("No proxy")] + NoProxy, #[error("Failed to get downstream address: {0}")] DownstreamAddr(#[source] io::Error), #[error("Failed to establish proxy chain: {0}")] diff --git a/access_server/src/udp/mod.rs b/access_server/src/udp/mod.rs index 95ad084..e9e3980 100644 --- a/access_server/src/udp/mod.rs +++ b/access_server/src/udp/mod.rs @@ -138,7 +138,10 @@ impl UdpAccess { downstream_writer: UdpDownstreamWriter, ) -> Result<(), AccessProxyError> { // Connect to upstream - let proxy_chain = self.proxy_table.choose_chain(); + let Some(proxy_table_group) = self.proxy_table.group(&flow.flow().upstream.0) else { + return Err(AccessProxyError::NoProxy); + }; + let proxy_chain = proxy_table_group.choose_chain(); let upstream = UdpProxyClient::establish(proxy_chain.chain.clone(), self.destination.clone()).await?; let upstream_remote = upstream.remote_addr().clone(); @@ -174,6 +177,8 @@ impl UdpAccess { #[derive(Debug, Error)] pub enum AccessProxyError { + #[error("No proxy")] + NoProxy, #[error("Failed to establish proxy chain: {0}")] Establish(#[from] EstablishError), #[error("Failed to copy: {0}")] diff --git a/access_server/src/udp/proxy_table.rs b/access_server/src/udp/proxy_table.rs index c0dd61a..860c9ce 100644 --- a/access_server/src/udp/proxy_table.rs +++ b/access_server/src/udp/proxy_table.rs @@ -1,9 +1,11 @@ use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; use common::{ - proxy_table::{ProxyTable, ProxyTableError}, + filter::MatcherBuilder, + proxy_table::{ProxyTable, ProxyTableError, ProxyTableGroup}, udp::proxy_table::{ - UdpProxyConfig, UdpProxyConfigBuildError, UdpProxyTable, UdpWeightedProxyChainBuilder, + UdpProxyConfig, UdpProxyConfigBuildError, UdpProxyTable, UdpProxyTableGroup, + UdpWeightedProxyChainBuilder, }, }; use proxy_client::udp::UdpTracer; @@ -14,17 +16,41 @@ use tokio_util::sync::CancellationToken; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct UdpProxyTableBuilder { + pub groups: Vec, +} +impl UdpProxyTableBuilder { + pub fn build( + self, + udp_proxy: &HashMap, UdpProxyConfig>, + cancellation: CancellationToken, + ) -> Result { + let mut built = vec![]; + for group in self.groups { + let g = group.build(udp_proxy, cancellation.clone())?; + built.push(g); + } + Ok(ProxyTable::new(built)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct UdpProxyTableGroupBuilder { + pub matcher: MatcherBuilder, pub chains: Vec, pub trace_rtt: bool, pub active_chains: Option, } - -impl UdpProxyTableBuilder { +impl UdpProxyTableGroupBuilder { pub fn build( self, udp_proxy: &HashMap, UdpProxyConfig>, cancellation: CancellationToken, - ) -> Result { + ) -> Result { + let matcher = self + .matcher + .build() + .map_err(UdpProxyTableBuildError::Matcher)?; let chains = self .chains .into_iter() @@ -35,7 +61,8 @@ impl UdpProxyTableBuilder { true => Some(UdpTracer::new()), false => None, }; - Ok(ProxyTable::new( + Ok(ProxyTableGroup::new( + matcher, chains, tracer, self.active_chains, @@ -46,6 +73,8 @@ impl UdpProxyTableBuilder { #[derive(Debug, Error)] pub enum UdpProxyTableBuildError { + #[error("Matcher: {0}")] + Matcher(#[source] regex::Error), #[error("Chain config is invalid: {0}")] ChainConfig(#[source] UdpProxyConfigBuildError), #[error("{0}")] diff --git a/common/Cargo.toml b/common/Cargo.toml index e149cd4..afa9f6a 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -21,7 +21,7 @@ monitor_table = { workspace = true } once_cell = { workspace = true } pin-project-lite = "0.2" rand = "0.8" -regex = "1" +regex = { workspace = true } scopeguard = "1" serde = { workspace = true, features = ["derive", "rc"] } slotmap = "1" diff --git a/common/src/filter.rs b/common/src/filter.rs index 2abefd6..42d8440 100644 --- a/common/src/filter.rs +++ b/common/src/filter.rs @@ -206,6 +206,14 @@ impl MatcherBuilderKind { #[derive(Debug, Clone)] pub struct Matcher(MatcherKind); +impl Matcher { + pub fn matches(&self, addr: &InternetAddr) -> bool { + match addr.deref() { + InternetAddrKind::SocketAddr(addr) => self.0.is_match_ip(*addr), + InternetAddrKind::DomainName { addr, port } => self.0.is_match_domain_name(addr, *port), + } + } +} #[derive(Debug, Clone)] enum MatcherKind { diff --git a/common/src/proxy_table.rs b/common/src/proxy_table.rs index 44c8d56..f143842 100644 --- a/common/src/proxy_table.rs +++ b/common/src/proxy_table.rs @@ -11,7 +11,10 @@ use thiserror::Error; use tokio_util::sync::CancellationToken; use tracing::{info, trace}; -use crate::{cache_cell::CacheCell, error::AnyError, header::route::RouteRequest}; +use crate::{ + addr::InternetAddr, cache_cell::CacheCell, error::AnyError, filter::Matcher, + header::route::RouteRequest, +}; const TRACE_INTERVAL: Duration = Duration::from_secs(30); const TRACE_DEAD_INTERVAL: Duration = Duration::from_secs(60 * 2); @@ -57,17 +60,37 @@ where #[derive(Debug, Clone)] pub struct ProxyTable { + groups: Vec>, +} +impl ProxyTable +where + A: std::fmt::Debug + Display + Clone + Send + Sync + 'static, +{ + pub fn new(groups: Vec>) -> Self { + Self { groups } + } + + pub fn group(&self, addr: &InternetAddr) -> Option<&ProxyTableGroup> { + self.groups + .iter() + .find(|&group| group.matcher().matches(addr)) + } +} + +#[derive(Debug, Clone)] +pub struct ProxyTableGroup { + matcher: Matcher, chains: Arc<[GaugedProxyChain]>, cum_weight: NonZeroUsize, score_store: Arc>, active_chains: NonZeroUsize, } - -impl ProxyTable +impl ProxyTableGroup where A: std::fmt::Debug + Display + Clone + Send + Sync + 'static, { pub fn new( + matcher: Matcher, chains: Vec>, tracer: Option, active_chains: Option, @@ -99,6 +122,7 @@ where .collect::>(); let score_store = Arc::new(RwLock::new(ScoreStore::new(None, TRACE_INTERVAL))); Ok(Self { + matcher, chains, cum_weight, score_store, @@ -106,6 +130,10 @@ where }) } + pub fn matcher(&self) -> &Matcher { + &self.matcher + } + pub fn choose_chain(&self) -> &WeightedProxyChain { if self.chains.len() == 1 { return self.chains[0].weighted(); @@ -165,7 +193,6 @@ where scores[..self.active_chains.get()].to_vec() } } - #[derive(Debug, Error)] pub enum ProxyTableError { #[error("Zero accumulated weight with chains")] diff --git a/common/src/stream/proxy_table.rs b/common/src/stream/proxy_table.rs index 0def419..ab2216e 100644 --- a/common/src/stream/proxy_table.rs +++ b/common/src/stream/proxy_table.rs @@ -5,7 +5,7 @@ use thiserror::Error; use crate::{ config::SharableConfig, - proxy_table::{ProxyConfig, ProxyTable, WeightedProxyChain}, + proxy_table::{ProxyConfig, ProxyTable, ProxyTableGroup, WeightedProxyChain}, }; use super::{addr::StreamAddrStr, StreamAddr}; @@ -79,3 +79,4 @@ pub type StreamProxyConfig = ProxyConfig>; pub type StreamProxyChain = [StreamProxyConfig]; pub type StreamWeightedProxyChain = WeightedProxyChain>; pub type StreamProxyTable = ProxyTable>; +pub type StreamProxyTableGroup = ProxyTableGroup>; diff --git a/common/src/udp/proxy_table.rs b/common/src/udp/proxy_table.rs index 15f378d..27013c4 100644 --- a/common/src/udp/proxy_table.rs +++ b/common/src/udp/proxy_table.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::{ addr::{InternetAddr, InternetAddrStr}, config::SharableConfig, - proxy_table::{ProxyConfig, ProxyTable, WeightedProxyChain}, + proxy_table::{ProxyConfig, ProxyTable, ProxyTableGroup, WeightedProxyChain}, }; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -72,3 +72,4 @@ pub type UdpProxyConfig = ProxyConfig; pub type UdpProxyChain = [UdpProxyConfig]; pub type UdpWeightedProxyChain = WeightedProxyChain; pub type UdpProxyTable = ProxyTable; +pub type UdpProxyTableGroup = ProxyTableGroup; diff --git a/server/config.toml b/server/config.toml index dd17b85..be31c97 100644 --- a/server/config.toml +++ b/server/config.toml @@ -5,23 +5,27 @@ stream_pool = [ ] [access_server.stream_proxy_tables] -"default" = { chains = [ - { weight = 1, chain = [ - { address = "tcp://proxy.example.org:80", header_key = "base64" }, - ], payload_key = "base64" }, - { weight = 1, chain = [ - { address = "kcp://proxy.example.org:80", header_key = "base64" }, - ], payload_key = "base64" }, - { weight = 1, chain = [ - { address = "mptcp://proxy.example.org:81", header_key = "base64" }, - ], payload_key = "base64" }, -], trace_rtt = false, active_chains = 3 } +"default" = { groups = [ + { matcher = { }, chains = [ + { weight = 1, chain = [ + { address = "tcp://proxy.example.org:80", header_key = "base64" }, + ], payload_key = "base64" }, + { weight = 1, chain = [ + { address = "kcp://proxy.example.org:80", header_key = "base64" }, + ], payload_key = "base64" }, + { weight = 1, chain = [ + { address = "mptcp://proxy.example.org:81", header_key = "base64" }, + ], payload_key = "base64" }, + ], trace_rtt = false, active_chains = 3 }, +] } [access_server.udp_proxy_tables] -"default" = { chains = [ - { weight = 1, chain = [ - { address = "proxy.example.org:80", header_key = "base64" }, - ], payload_key = "base64" }, -], trace_rtt = false, active_chains = 1 } +"default" = { groups = [ + { matcher = { }, chains = [ + { weight = 1, chain = [ + { address = "proxy.example.org:80", header_key = "base64" }, + ], payload_key = "base64" }, + ], trace_rtt = false, active_chains = 1 }, +] } [access_server.filters] "default" = [ "localhost",