Skip to content

Commit

Permalink
feat: semantically merge configs
Browse files Browse the repository at this point in the history
  • Loading branch information
Banyc committed Feb 19, 2024
1 parent 8604252 commit f07a227
Show file tree
Hide file tree
Showing 11 changed files with 313 additions and 239 deletions.
275 changes: 123 additions & 152 deletions Cargo.lock

Large diffs are not rendered by default.

43 changes: 30 additions & 13 deletions access_server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{collections::HashMap, sync::Arc};

use common::{
error::AnyResult,
config::{merge_map, Merge},
error::{AnyError, AnyResult},
filter::{self, FilterBuilder, MatcherBuilder},
loading,
stream::proxy_table::StreamProxyConfig,
Expand All @@ -27,7 +28,7 @@ pub mod socks5;
pub mod stream;
pub mod udp;

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct AccessServerConfig {
#[serde(default)]
Expand All @@ -49,7 +50,6 @@ pub struct AccessServerConfig {
#[serde(default)]
pub filters: HashMap<Arc<str>, FilterBuilder>,
}

impl AccessServerConfig {
pub fn new() -> AccessServerConfig {
AccessServerConfig {
Expand Down Expand Up @@ -190,21 +190,44 @@ impl AccessServerConfig {
Ok(())
}
}
impl Merge for AccessServerConfig {
type Error = AnyError;

impl Default for AccessServerConfig {
fn default() -> Self {
Self::new()
fn merge(mut self, other: Self) -> Result<Self, Self::Error>
where
Self: Sized,
{
self.tcp_server.extend(other.tcp_server);
self.udp_server.extend(other.udp_server);
self.http_server.extend(other.http_server);
self.socks5_tcp_server.extend(other.socks5_tcp_server);
self.socks5_udp_server.extend(other.socks5_udp_server);
let stream_proxy_tables = merge_map(self.stream_proxy_tables, other.stream_proxy_tables)?;
let udp_proxy_tables = merge_map(self.udp_proxy_tables, other.udp_proxy_tables)?;
let matchers = merge_map(self.matchers, other.matchers)?;
let filters = merge_map(self.filters, other.filters)?;
Ok(Self {
tcp_server: self.tcp_server,
udp_server: self.udp_server,
http_server: self.http_server,
socks5_tcp_server: self.socks5_tcp_server,
socks5_udp_server: self.socks5_udp_server,
stream_proxy_tables,
udp_proxy_tables,
matchers,
filters,
})
}
}

#[derive(Default)]
pub struct AccessServerLoader {
tcp_server: loading::Loader<TcpAccess>,
udp_server: loading::Loader<UdpAccess>,
http_server: loading::Loader<HttpAccess>,
socks5_tcp_server: loading::Loader<Socks5ServerTcpAccess>,
socks5_udp_server: loading::Loader<Socks5ServerUdpAccess>,
}

impl AccessServerLoader {
pub fn new() -> Self {
Self {
Expand All @@ -216,9 +239,3 @@ impl AccessServerLoader {
}
}
}

impl Default for AccessServerLoader {
fn default() -> Self {
Self::new()
}
}
32 changes: 31 additions & 1 deletion common/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,40 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use serde::{Deserialize, Serialize};
use thiserror::Error;

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SharableConfig<T> {
SharingKey(Arc<str>),
Private(T),
}

pub trait Merge: Default {
type Error;
/// # Error
///
/// - Attempted key overriding
fn merge(self, other: Self) -> Result<Self, Self::Error>
where
Self: Sized;
}

pub fn merge_map<K, V>(
mut a: HashMap<K, V>,
b: HashMap<K, V>,
) -> Result<HashMap<K, V>, RepeatKeyMergeError<K>>
where
K: Eq + std::hash::Hash,
{
for (k, v) in b {
if a.contains_key(&k) {
return Err(RepeatKeyMergeError(k));
}
a.insert(k, v);
}
Ok(a)
}
#[derive(Debug, Error)]
#[error("Repeated key: {0}")]
pub struct RepeatKeyMergeError<K>(pub K);
26 changes: 19 additions & 7 deletions common/src/stream/pool.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
collections::HashMap, io, marker::PhantomData, net::SocketAddr, sync::Arc, time::Duration,
collections::HashMap, convert::Infallible, io, marker::PhantomData, net::SocketAddr, sync::Arc,
time::Duration,
};

use async_trait::async_trait;
Expand All @@ -8,7 +9,7 @@ use thiserror::Error;
use tokio_conn_pool::{ConnPool, ConnPoolEntry};

use crate::{
config::SharableConfig,
config::{Merge, SharableConfig},
header::heartbeat::send_noop,
proxy_table::ProxyConfig,
stream::{
Expand Down Expand Up @@ -70,18 +71,29 @@ where
Ok(pool)
}
}
impl<SAS> Default for PoolBuilder<SAS> {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Error)]
pub enum PoolBuildError {
#[error("{0}")]
StreamProxyConfigBuild(#[from] StreamProxyConfigBuildError),
#[error("Key not found: {0}")]
KeyNotFound(Arc<str>),
}
impl<SAS> Default for PoolBuilder<SAS> {
fn default() -> Self {
Self::new()
}
}
impl<SAS> Merge for PoolBuilder<SAS> {
type Error = Infallible;

fn merge(mut self, other: Self) -> Result<Self, Self::Error>
where
Self: Sized,
{
self.0.extend(other.0);
Ok(Self(self.0))
}
}

fn pool_entries_from_proxy_configs<C, CT: Clone, ST>(
proxy_configs: impl Iterator<Item = ProxyConfig<StreamAddr<ST>>>,
Expand Down
26 changes: 19 additions & 7 deletions proxy_server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;
use std::{convert::Infallible, io};

use common::{error::AnyResult, loading};
use common::{config::Merge, error::AnyResult, loading};
use protocol::context::ConcreteContext;
use serde::Deserialize;
use stream::{
Expand All @@ -13,7 +13,7 @@ use udp::{UdpProxy, UdpProxyServerBuilder, UdpProxyServerConfig};
pub mod stream;
pub mod udp;

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct ProxyServerConfig {
#[serde(default)]
Expand All @@ -25,7 +25,6 @@ pub struct ProxyServerConfig {
#[serde(default)]
pub mptcp_server: Vec<MptcpProxyServerConfig>,
}

impl ProxyServerConfig {
pub fn new() -> Self {
Self {
Expand Down Expand Up @@ -92,10 +91,23 @@ impl ProxyServerConfig {
Ok(())
}
}
impl Merge for ProxyServerConfig {
type Error = Infallible;

impl Default for ProxyServerConfig {
fn default() -> Self {
Self::new()
fn merge(mut self, other: Self) -> Result<Self, Self::Error>
where
Self: Sized,
{
self.tcp_server.extend(other.tcp_server);
self.udp_server.extend(other.udp_server);
self.kcp_server.extend(other.kcp_server);
self.mptcp_server.extend(other.mptcp_server);
Ok(Self {
tcp_server: self.tcp_server,
udp_server: self.udp_server,
kcp_server: self.kcp_server,
mptcp_server: self.mptcp_server,
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion proxy_server/src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ mod tests {
use swap::Swap;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
net::TcpStream,
};

#[tokio::test(flavor = "multi_thread")]
Expand Down
1 change: 1 addition & 0 deletions server/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use common::error::AnyError;
use file_watcher_tokio::EventActor;

pub mod multi_file_config;
pub mod toml;

pub trait ConfigReader {
type Config;
Expand Down
67 changes: 13 additions & 54 deletions server/src/config/multi_file_config.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::{marker::PhantomData, sync::Arc};

use common::error::AnyError;
use common::{config::Merge, error::AnyError};
use serde::Deserialize;

use crate::ConfigReader;

use super::ConfigWatcher;
use super::{toml::human_toml_error, ConfigWatcher};

pub fn spawn_watch_tasks(config_file_paths: &[Arc<str>]) -> Arc<tokio::sync::Notify> {
let watcher = ConfigWatcher::new();
Expand All @@ -18,12 +18,12 @@ pub fn spawn_watch_tasks(config_file_paths: &[Arc<str>]) -> Arc<tokio::sync::Not
notify_rx
}

pub struct MultiFileConfigReader<C> {
pub struct MultiConfigReader<C> {
config_file_paths: Arc<[Arc<str>]>,
phantom_config: PhantomData<C>,
}

impl<C> MultiFileConfigReader<C> {
impl<C> MultiConfigReader<C> {
pub fn new(config_file_paths: Arc<[Arc<str>]>) -> Self {
Self {
config_file_paths,
Expand All @@ -32,60 +32,19 @@ impl<C> MultiFileConfigReader<C> {
}
}

impl<C> ConfigReader for MultiFileConfigReader<C>
impl<C> ConfigReader for MultiConfigReader<C>
where
for<'de> C: Deserialize<'de> + Send + Sync + 'static,
C: Merge<Error = AnyError>,
{
type Config = C;
async fn read_config(&self) -> Result<Self::Config, AnyError> {
read_multi_file_config(&self.config_file_paths).await
}
}

pub async fn read_multi_file_config<C>(config_file_paths: &[Arc<str>]) -> Result<C, AnyError>
where
for<'de> C: Deserialize<'de>,
{
let mut config_str = String::new();
for path in config_file_paths {
let src = tokio::fs::read_to_string(path.as_ref()).await?;
config_str.push_str(&src);
}
let config: C = toml::from_str(&config_str).map_err(|e| human_toml_error(&config_str, e))?;
Ok(config)
}

fn human_toml_error(src: &str, e: toml::de::Error) -> String {
let Some(span) = e.span() else {
return format!("{e}");
};
let affected = src
.chars()
.skip(span.start)
.take(span.end - span.start)
.collect::<String>();
let (line, col) = {
let mut line = 1;
let mut col = 1;

for (i, char) in src.chars().enumerate() {
if i == span.start {
break;
}
if char == '\n' {
line += 1;
col = 1;
}
col += 1;
let mut config = C::default();
for path in self.config_file_paths.iter() {
let src = tokio::fs::read_to_string(path.as_ref()).await?;
let c: C = toml::from_str(&src).map_err(|e| human_toml_error(path, &src, e))?;
config = config.merge(c)?;
}

(line, col)
};
let msg = e.message();
let e = format!(
"{msg}
Line {line}, Column {col}
Affected: #'{affected}'#"
);
e
Ok(config)
}
}
35 changes: 35 additions & 0 deletions server/src/config/toml.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
pub fn human_toml_error(file_path: &str, src: &str, e: toml::de::Error) -> String {
let Some(span) = e.span() else {
return format!("{e}");
};
let affected = src
.chars()
.skip(span.start)
.take(span.end - span.start)
.collect::<String>();
let (line, col) = {
let mut line = 1;
let mut col = 1;

for (i, char) in src.chars().enumerate() {
if i == span.start {
break;
}
if char == '\n' {
line += 1;
col = 1;
}
col += 1;
}

(line, col)
};
let msg = e.message();
let e = format!(
"{msg}
File `{file_path}`
Line {line}, Column {col}
Affected: #'{affected}'#"
);
e
}
Loading

0 comments on commit f07a227

Please sign in to comment.