diff --git a/forward-proxy/src/config.rs b/forward-proxy/src/config.rs index 34cf625..d39ac5a 100644 --- a/forward-proxy/src/config.rs +++ b/forward-proxy/src/config.rs @@ -52,4 +52,4 @@ pub struct InfluxDBConfig { pub influxdb_org: String, pub influxdb_bucket: String, pub influxdb_auth_token: String, -} \ No newline at end of file +} diff --git a/forward-proxy/src/handler/consts.rs b/forward-proxy/src/handler/consts.rs index 9253f6e..5eaca2c 100644 --- a/forward-proxy/src/handler/consts.rs +++ b/forward-proxy/src/handler/consts.rs @@ -42,4 +42,3 @@ impl RequestPaths { pub const INIT_TUNNEL: &'static str = "/init-tunnel"; pub const HEALTHCHECK: &'static str = "/healthcheck"; } - diff --git a/forward-proxy/src/handler/mod.rs b/forward-proxy/src/handler/mod.rs index 9c4805c..16596a6 100644 --- a/forward-proxy/src/handler/mod.rs +++ b/forward-proxy/src/handler/mod.rs @@ -2,24 +2,27 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use pingora::http::StatusCode; -use reqwest::Client; use pingora_router::{ - ctx::{Layer8Context, Layer8ContextTrait}, - handler::{APIHandlerResponse, DefaultHandlerTrait, RequestBodyTrait, ResponseBodyTrait} + ctx::{Layer8Context, Layer8ContextTrait}, + handler::{APIHandlerResponse, DefaultHandlerTrait, RequestBodyTrait, ResponseBodyTrait}, }; +use reqwest::Client; use serde::Deserialize; use tracing::{debug, error, info}; +use crate::config::HandlerConfig; +use crate::handler::consts::LogTypes; use crate::handler::types::{ - response::{ErrorResponse, FpHealthcheckError, FpHealthcheckSuccess, InitTunnelResponseFromRP, InitTunnelResponseToINT}, - request::InitTunnelRequest + request::InitTunnelRequest, + response::{ + ErrorResponse, FpHealthcheckError, FpHealthcheckSuccess, InitTunnelResponseFromRP, + InitTunnelResponseToINT, + }, }; use utils::{self, jwt::JWTClaims}; -use crate::config::HandlerConfig; -use crate::handler::consts::LogTypes; -pub mod types; pub mod consts; +pub mod types; pub struct ForwardHandler { pub config: HandlerConfig, @@ -53,8 +56,7 @@ impl ForwardHandler { &self, backend_url: String, ctx: &mut Layer8Context, - ) -> Result - { + ) -> Result { let correlation_id = ctx.get_correlation_id(); let client = Client::new(); @@ -68,14 +70,18 @@ impl ForwardHandler { self.config.auth_get_certificate_url, backend_url.replace("http://", "").replace("https://", "") ); - let res = client.get(&request_path) - .header("Authorization", format!("Bearer {}", self.config.auth_access_token)) + let res = client + .get(&request_path) + .header( + "Authorization", + format!("Bearer {}", self.config.auth_access_token), + ) .send() .await // unable to connect .map_err(|e| { let response_body = ErrorResponse { - error: format!("Failed to connect to layer8: {}", e) + error: format!("Failed to connect to layer8: {}", e), }; APIHandlerResponse { @@ -87,7 +93,10 @@ impl ForwardHandler { // connected but request failed if !res.status().is_success() { let response_body = ErrorResponse { - error: format!("Failed to get public key from layer8, status code: {}", res.status().as_u16()), + error: format!( + "Failed to get public key from layer8, status code: {}", + res.status().as_u16() + ), }; error!( %correlation_id, @@ -96,24 +105,43 @@ impl ForwardHandler { ); ctx.insert_response_header("Connection", "close"); // Ensure connection closes??? - - Err(APIHandlerResponse { + return Err(APIHandlerResponse { status: StatusCode::BAD_REQUEST, body: Some(response_body.to_bytes()), - }) - } else { - #[derive(Deserialize, Debug)] - struct AuthServerResponse { - pub x509_certificate: String, - pub client_id: String, + }); + } + #[derive(Deserialize, Debug)] + struct AuthServerResponse { + pub x509_certificate: String, + pub client_id: String, + } + + let auth_res: AuthServerResponse = res.json().await.map_err(|err| { + error!( + %correlation_id, + log_type=LogTypes::AUTHENTICATION_SERVER, + "Failed to parse authentication server response: {:?}", + err + ); + APIHandlerResponse { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: None, } + })?; - let auth_res: AuthServerResponse = res.json().await.map_err(|err| { + // save `client_id` to ctx for later use + ctx.set( + consts::CtxKeys::BACKEND_AUTH_CLIENT_ID.to_string(), + auth_res.client_id.clone(), + ); + + let pub_key = + utils::cert::extract_x509_pem(auth_res.x509_certificate.clone()).map_err(|e| { error!( %correlation_id, log_type=LogTypes::AUTHENTICATION_SERVER, - "Failed to parse authentication server response: {:?}", - err + "Failed to parse x509 certificate: {:?}", + e ); APIHandlerResponse { status: StatusCode::INTERNAL_SERVER_ERROR, @@ -121,43 +149,22 @@ impl ForwardHandler { } })?; - // save `client_id` to ctx for later use - ctx.set(consts::CtxKeys::BACKEND_AUTH_CLIENT_ID.to_string(), auth_res.client_id.clone()); - - let pub_key = utils::cert::extract_x509_pem(auth_res.x509_certificate.clone()) - .map_err(|e| { - error!( - %correlation_id, - log_type=LogTypes::AUTHENTICATION_SERVER, - "Failed to parse x509 certificate: {:?}", - e - ); - APIHandlerResponse { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: None, - } - })?; - - debug!(%correlation_id, "AuthenticationServer response: {:?}", auth_res); - info!( - %correlation_id, - log_type=LogTypes::AUTHENTICATION_SERVER, - "Obtained ntor credentials for backend_url: {}", - backend_url - ); + debug!(%correlation_id, "AuthenticationServer response: {:?}", auth_res); + info!( + %correlation_id, + log_type=LogTypes::AUTHENTICATION_SERVER, + "Obtained ntor credentials for backend_url: {}", + backend_url + ); - Ok(NTorServerCertificate { - server_id: backend_url, // todo I still prefer taking the server_id value from certificate's subject - public_key: pub_key, - }) - } + Ok(NTorServerCertificate { + server_id: backend_url, // todo I still prefer taking the server_id value from certificate's subject + public_key: pub_key, + }) } /// Verify `int_fp_jwt` and return `fp_rp_jwt` - pub fn verify_int_fp_jwt( - &self, - token: &str, - ) -> Result { + pub fn verify_int_fp_jwt(&self, token: &str) -> Result { return match utils::jwt::verify_jwt_token(token, &self.config.jwt_virtual_connection_key) { Ok(_claims) => { // todo check claims if needed @@ -166,47 +173,47 @@ impl ForwardHandler { let jwts = self.jwts_storage.lock().unwrap(); jwts.get(token).cloned() } { - None => { - Err("token not found!".to_string()) - } - Some(session) => Ok(session) + None => Err("token not found!".to_string()), + Some(session) => Ok(session), } } - Err(err) => Err(err.to_string()) + Err(err) => Err(err.to_string()), }; } /// Validate request body and get ntor certificate for the given backend URL. pub async fn handle_init_tunnel_request(&self, ctx: &mut Layer8Context) -> APIHandlerResponse { // validate request body - let received_body = match ForwardHandler::parse_request_body::< - InitTunnelRequest, - ErrorResponse - >(&ctx.get_request_body()) - { - Ok(res) => res.to_bytes(), - Err(Some(e)) => { - return APIHandlerResponse { - status: StatusCode::BAD_REQUEST, - body: Some(e.to_bytes()), - }; - } - Err(None) => { - return APIHandlerResponse { - status: StatusCode::BAD_REQUEST, - body: None, - }; - } - }; + let received_body = + match ForwardHandler::parse_request_body::( + &ctx.get_request_body(), + ) { + Ok(res) => res.to_bytes(), + Err(Some(e)) => { + return APIHandlerResponse { + status: StatusCode::BAD_REQUEST, + body: Some(e.to_bytes()), + }; + } + Err(None) => { + return APIHandlerResponse { + status: StatusCode::BAD_REQUEST, + body: None, + }; + } + }; // get public key to initialize encrypted tunnel { // it's safe to use unwrap here because this param was already checked in `request_filter` - let backend_url = ctx.param("backend_url").unwrap_or(&"".to_string()).to_string(); + let backend_url = ctx + .param("backend_url") + .unwrap_or(&"".to_string()) + .to_string(); let server_certificate = match self.get_public_key(backend_url.to_string(), ctx).await { Ok(cert) => cert, - Err(err) => return err + Err(err) => return err, }; debug!("Server certificate: {:?}", server_certificate); @@ -227,65 +234,73 @@ impl ForwardHandler { } pub fn handle_init_tunnel_response(&self, ctx: &mut Layer8Context) -> APIHandlerResponse { - let ntor_server_id = ctx.get(&consts::CtxKeys::NTOR_SERVER_ID.to_string()).unwrap_or(&"".to_string()).clone(); + let ntor_server_id = ctx + .get(consts::CtxKeys::NTOR_SERVER_ID) + .unwrap_or(&"".to_string()) + .clone(); let ntor_static_public_key = hex::decode( - ctx.get(&consts::CtxKeys::NTOR_STATIC_PUBLIC_KEY.to_string()).clone().unwrap_or(&"".to_string()) - ).unwrap_or_default(); + ctx.get(consts::CtxKeys::NTOR_STATIC_PUBLIC_KEY) + .unwrap_or(&"".to_string()), + ) + .unwrap_or_default(); let response_body = ctx.get_response_body(); - - return match utils::bytes_to_json::(response_body) { + let res_from_rp = match utils::bytes_to_json::(response_body) { + Ok(val) => val, Err(e) => { error!( - correlation_id=ctx.get_correlation_id(), - log_type=LogTypes::HANDLE_UPSTREAM_RESPONSE, + correlation_id = ctx.get_correlation_id(), + log_type = LogTypes::HANDLE_UPSTREAM_RESPONSE, "Error parsing RP response: {:?}", e ); - APIHandlerResponse { + return APIHandlerResponse { status: StatusCode::INTERNAL_SERVER_ERROR, body: None, - } - } - Ok(res_from_rp) => { - let int_fp_jwt = { - let mut claims = JWTClaims::new(Some(self.config.jwt_exp_in_hours)); - claims.uuid = Some(utils::new_uuid()); - utils::jwt::create_jwt_token(claims, &self.config.jwt_virtual_connection_key) }; + } + }; - let int_fp_session = IntFPSession { - client_id: ctx.get(&consts::CtxKeys::BACKEND_AUTH_CLIENT_ID.to_string()).unwrap_or(&"".to_string()).to_string(), - rp_base_url: ctx.param("backend_url").unwrap_or(&"".to_string()).to_string(), - fp_rp_jwt: res_from_rp.fp_rp_jwt, - }; + let int_fp_jwt = { + let mut claims = JWTClaims::new(Some(self.config.jwt_exp_in_hours)); + claims.uuid = Some(utils::new_uuid()); + utils::jwt::create_jwt_token(claims, &self.config.jwt_virtual_connection_key) + }; - let mut jwts = self.jwts_storage.lock().unwrap(); - jwts.insert(int_fp_jwt.clone(), int_fp_session); + let int_fp_session = IntFPSession { + client_id: ctx + .get(consts::CtxKeys::BACKEND_AUTH_CLIENT_ID) + .unwrap_or(&"".to_string()) + .clone(), + rp_base_url: ctx.param("backend_url").unwrap_or(&"".to_string()).clone(), + fp_rp_jwt: res_from_rp.fp_rp_jwt, + }; - let res_to_int = InitTunnelResponseToINT { - ephemeral_public_key: res_from_rp.public_key, - t_b_hash: res_from_rp.t_b_hash, - int_rp_jwt: res_from_rp.int_rp_jwt, - int_fp_jwt, - ntor_static_public_key, - ntor_server_id, - }; + let mut jwts = self.jwts_storage.lock().unwrap(); + jwts.insert(int_fp_jwt.clone(), int_fp_session); - APIHandlerResponse { - status: StatusCode::OK, - body: Some(res_to_int.to_bytes()), - } - } + let res_to_int = InitTunnelResponseToINT { + ephemeral_public_key: res_from_rp.public_key, + t_b_hash: res_from_rp.t_b_hash, + int_rp_jwt: res_from_rp.int_rp_jwt, + int_fp_jwt, + ntor_static_public_key, + ntor_server_id, }; + + APIHandlerResponse { + status: StatusCode::OK, + body: Some(res_to_int.to_bytes()), + } } pub fn handle_healthcheck(&self, ctx: &mut Layer8Context) -> APIHandlerResponse { if let Some(error) = ctx.param("error") { if error == "true" { let response_bytes = FpHealthcheckError { - fp_healthcheck_error: "this is placeholder for a custom error".to_string() - }.to_bytes(); + fp_healthcheck_error: "this is placeholder for a custom error".to_string(), + } + .to_bytes(); ctx.insert_response_header("x-fp-healthcheck-error", "response-header-error"); return APIHandlerResponse { @@ -297,7 +312,8 @@ impl ForwardHandler { let response_bytes = FpHealthcheckSuccess { fp_healthcheck_success: "this is placeholder for a custom body".to_string(), - }.to_bytes(); + } + .to_bytes(); ctx.insert_response_header("x-fp-healthcheck-success", "response-header-success"); @@ -306,4 +322,4 @@ impl ForwardHandler { body: Some(response_bytes), }; } -} \ No newline at end of file +} diff --git a/forward-proxy/src/handler/types/mod.rs b/forward-proxy/src/handler/types/mod.rs index 0ec9b86..e006218 100644 --- a/forward-proxy/src/handler/types/mod.rs +++ b/forward-proxy/src/handler/types/mod.rs @@ -1,2 +1,2 @@ -pub mod response; pub mod request; +pub mod response; diff --git a/forward-proxy/src/handler/types/request.rs b/forward-proxy/src/handler/types/request.rs index 5f52bb8..39ad41c 100644 --- a/forward-proxy/src/handler/types/request.rs +++ b/forward-proxy/src/handler/types/request.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use pingora_router::handler::RequestBodyTrait; +use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug)] pub struct InitTunnelRequest { diff --git a/forward-proxy/src/handler/types/response.rs b/forward-proxy/src/handler/types/response.rs index 478eff2..c7f0bc0 100644 --- a/forward-proxy/src/handler/types/response.rs +++ b/forward-proxy/src/handler/types/response.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use pingora_router::handler::ResponseBodyTrait; +use serde::{Deserialize, Serialize}; use serde_json::Error; #[derive(Serialize, Deserialize, Debug)] @@ -10,13 +10,14 @@ pub struct ErrorResponse { impl ResponseBodyTrait for ErrorResponse { fn from_json_err(err: Error) -> Option { Some(ErrorResponse { - error: err.to_string() + error: err.to_string(), }) } } #[derive(Serialize, Deserialize, Debug)] -pub struct InitTunnelResponseFromRP { // this struct should match ReverseProxy's Response +pub struct InitTunnelResponseFromRP { + // this struct should match ReverseProxy's Response pub public_key: Vec, pub t_b_hash: Vec, #[serde(rename = "jwt1")] // a little bit of obfuscation @@ -28,7 +29,8 @@ pub struct InitTunnelResponseFromRP { // this struct should match ReverseProxy's impl ResponseBodyTrait for InitTunnelResponseFromRP {} #[derive(Serialize, Deserialize, Debug)] -pub struct InitTunnelResponseToINT { // this struct should match Interceptor's expected Response +pub struct InitTunnelResponseToINT { + // this struct should match Interceptor's expected Response pub ephemeral_public_key: Vec, pub t_b_hash: Vec, #[serde(rename = "jwt1")] @@ -38,7 +40,7 @@ pub struct InitTunnelResponseToINT { // this struct should match Interceptor's e #[serde(rename = "public_key")] pub ntor_static_public_key: Vec, #[serde(rename = "server_id")] - pub ntor_server_id: String + pub ntor_server_id: String, } impl ResponseBodyTrait for InitTunnelResponseToINT {} diff --git a/forward-proxy/src/main.rs b/forward-proxy/src/main.rs index 0b0e18a..58b03a5 100644 --- a/forward-proxy/src/main.rs +++ b/forward-proxy/src/main.rs @@ -1,15 +1,16 @@ -mod proxy; -mod handler; +use pingora::prelude::*; +use tokio::runtime::Runtime; +use tracing::{debug, info}; + mod config; +mod handler; +mod proxy; mod statistics; -use crate::handler::ForwardHandler; +use config::FPConfig; +use handler::ForwardHandler; use proxy::ForwardProxy; -use pingora::prelude::*; -use tokio::runtime::Runtime; -use crate::config::FPConfig; -use tracing::{info, debug}; -use crate::statistics::Statistics; +use statistics::Statistics; fn load_config() -> FPConfig { // Load environment variables from .env file @@ -30,7 +31,6 @@ fn main() { let rt = Runtime::new().unwrap(); rt.block_on(Statistics::init_influxdb_client(&config.influxdb_config)); - let _logger_guard = utils::log::init_logger( config.log_config.log_level.clone(), config.log_config.log_format.clone(), @@ -41,7 +41,8 @@ fn main() { let mut server = Server::new(Some(Opt { conf: std::env::var("SERVER_CONF").ok(), ..Default::default() - })).expect("Failed to create server"); + })) + .expect("Failed to create server"); server.bootstrap(); let fp_handler = ForwardHandler::new(config.handler_config); @@ -55,7 +56,10 @@ fn main() { server.add_service(proxy); - info!("Starting server at {}:{}", config.listen_address, config.listen_port); + info!( + "Starting server at {}:{}", + config.listen_address, config.listen_port + ); server.run_forever(); } diff --git a/forward-proxy/src/proxy.rs b/forward-proxy/src/proxy.rs index fcec0b8..3643476 100644 --- a/forward-proxy/src/proxy.rs +++ b/forward-proxy/src/proxy.rs @@ -1,8 +1,6 @@ -use crate::config::TlsConfig; -use crate::handler::ForwardHandler; -use crate::handler::consts::{CtxKeys, HeaderKeys, LogTypes, RequestPaths}; -use crate::handler::types::response::ErrorResponse; -use crate::statistics::Statistics; +use std::sync::Arc; +use std::time::Duration; + use async_trait::async_trait; use boring::x509::X509; use bytes::Bytes; @@ -17,10 +15,14 @@ use pingora_error::ErrorType; use pingora_router::ctx::{Layer8Context, Layer8ContextTrait}; use pingora_router::handler::ResponseBodyTrait; use reqwest::header::TRANSFER_ENCODING; -use std::sync::Arc; -use std::time::Duration; use tracing::{debug, error, info}; +use crate::config::TlsConfig; +use crate::handler::ForwardHandler; +use crate::handler::consts::{CtxKeys, HeaderKeys, LogTypes, RequestPaths}; +use crate::handler::types::response::ErrorResponse; +use crate::statistics::Statistics; + pub struct ForwardProxy { tls_config: TlsConfig, handler: ForwardHandler, @@ -66,11 +68,11 @@ impl ProxyHttp for ForwardProxy { let correlation_id = ctx.get_correlation_id(); let addrs = ctx - .get(&CtxKeys::UPSTREAM_ADDRESS.to_string()) + .get(&CtxKeys::UPSTREAM_ADDRESS) .unwrap_or(&"".to_string()) .clone(); let sni = ctx - .get(&CtxKeys::UPSTREAM_SNI.to_string()) + .get(&CtxKeys::UPSTREAM_SNI) .unwrap_or(&"".to_string()) .clone(); info!( @@ -361,7 +363,7 @@ impl ProxyHttp for ForwardProxy { request_summary = session.request_summary(), "Forward proxy passing through request body unchanged." ); - *body = Some(Bytes::copy_from_slice(ctx.get_request_body().as_slice())); + *body = Some(Bytes::copy_from_slice(ctx.get_request_body())); return Ok(()); } }; @@ -441,7 +443,6 @@ impl ProxyHttp for ForwardProxy { "{} token is empty", HeaderKeys::INT_FP_JWT ); - return Err(pingora::Error::new(pingora::ErrorType::HTTPStatus( u16::from(StatusCode::BAD_REQUEST), ))); @@ -545,7 +546,7 @@ impl ProxyHttp for ForwardProxy { request_summary = session.request_summary(), "Forward proxy passing through response body unchanged." ); - *body = Some(Bytes::copy_from_slice(ctx.get_response_body().as_slice())); + *body = Some(Bytes::copy_from_slice(ctx.get_response_body())); return Ok(None); } }; @@ -600,7 +601,7 @@ impl ProxyHttp for ForwardProxy { || session.req_header().uri.path() == RequestPaths::INIT_TUNNEL) { let client_id = ctx - .get(&CtxKeys::BACKEND_AUTH_CLIENT_ID.to_string()) + .get(&CtxKeys::BACKEND_AUTH_CLIENT_ID) .unwrap_or(&"".to_string()) .clone(); let request_path = session.req_header().uri.path().to_string(); @@ -647,7 +648,7 @@ impl ProxyHttp for ForwardProxy { || e.etype == ErrorType::ConnectRefused { let mut addrs = ctx - .get(&CtxKeys::UPSTREAM_ADDRESS.to_string()) + .get(&CtxKeys::UPSTREAM_ADDRESS) .unwrap_or(&"".to_string()) .clone(); diff --git a/pingora-router/Cargo.toml b/pingora-router/Cargo.toml index 6be7d3e..92d849f 100644 --- a/pingora-router/Cargo.toml +++ b/pingora-router/Cargo.toml @@ -9,4 +9,5 @@ pingora = { version = "0.5.0", features = ["lb", "boringssl"] } serde = "1.0.219" serde_json = "1.0.140" chrono = "0.4.40" -uuid = "1.16.0" \ No newline at end of file +uuid = { version = "1.16.0", features = ["v4"] } +bincode = "2.0.1" diff --git a/pingora-router/src/ctx.rs b/pingora-router/src/ctx.rs index 1426926..69ceeb8 100644 --- a/pingora-router/src/ctx.rs +++ b/pingora-router/src/ctx.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; use std::time::Instant; + use pingora::http::{Method, RequestHeader, StatusCode}; use pingora::proxy::Session; -use crate::utils::get_request_body; use uuid; +use crate::utils::get_request_body; + /* * Each type in this crate serves a specific purpose and may be updated as requirements evolve. */ @@ -22,10 +24,16 @@ pub struct Layer8ContextRequestSummary { impl Layer8ContextRequestSummary { pub(crate) fn from(session: &Session) -> Self { let method = session.req_header().method.clone(); - let scheme = session.req_header().uri.scheme() + let scheme = session + .req_header() + .uri + .scheme() .map(|s| s.to_string()) .unwrap_or_else(|| "".to_string()); - let host = session.req_header().uri.host() + let host = session + .req_header() + .uri + .host() .map(|h| h.to_string()) .unwrap_or_else(|| "".to_string()); let path = session.req_header().uri.path().to_string(); @@ -124,11 +132,10 @@ impl Layer8Context { pub async fn read_request_body(&mut self, session: &mut Session) -> pingora::Result { match get_request_body(session).await { Ok(body) => self.request.body = body, - Err(err) => return Err(err) + Err(err) => return Err(err), }; Ok(true) } - } impl Layer8ContextTrait for Layer8Context { @@ -149,8 +156,10 @@ impl Layer8ContextTrait for Layer8Context { fn set_request_header(&mut self, header: RequestHeader) { for (key, val) in header.headers.iter() { - self.request.header.insert(key.to_string(), val.to_str().unwrap_or("").to_string()); - }; + self.request + .header + .insert(key.to_string(), val.to_str().unwrap_or("").to_string()); + } } fn get_request_header(&self) -> &Layer8Header { @@ -158,7 +167,9 @@ impl Layer8ContextTrait for Layer8Context { } fn insert_response_header(&mut self, key: &str, val: &str) { - self.response.header.insert(key.to_lowercase().to_string(), val.to_string()); + self.response + .header + .insert(key.to_lowercase().to_string(), val.to_string()); } fn remove_response_header(&mut self, key: &str) -> Option { @@ -177,8 +188,8 @@ impl Layer8ContextTrait for Layer8Context { self.request.body.extend(body) } - fn get_request_body(&self) -> Vec { - self.request.body.clone() + fn get_request_body(&self) -> &[u8] { + &self.request.body } fn set_response_body(&mut self, body: Vec) { @@ -189,8 +200,8 @@ impl Layer8ContextTrait for Layer8Context { self.response.body.extend(body); } - fn get_response_body(&self) -> Vec { - self.response.body.clone() + fn get_response_body(&self) -> &[u8] { + &self.response.body } fn get(&self, key: &str) -> Option<&String> { @@ -245,10 +256,10 @@ pub trait Layer8ContextTrait { fn get_response_header(&self) -> &Layer8Header; fn set_request_body(&mut self, body: Vec); fn extend_request_body(&mut self, body: Vec); - fn get_request_body(&self) -> Vec; + fn get_request_body(&self) -> &[u8]; fn set_response_body(&mut self, body: Vec); fn extend_response_body(&mut self, body: Vec); - fn get_response_body(&self) -> Vec; + fn get_response_body(&self) -> &[u8]; fn get(&self, key: &str) -> Option<&String>; fn set(&mut self, key: String, value: String); fn set_request_summary(&mut self, summary: Layer8ContextRequestSummary); diff --git a/pingora-router/src/handler.rs b/pingora-router/src/handler.rs index e2f3413..d67d081 100644 --- a/pingora-router/src/handler.rs +++ b/pingora-router/src/handler.rs @@ -1,9 +1,9 @@ -use std::fmt::Debug; -use pingora::http::StatusCode; use crate::ctx::Layer8Context; use futures::future::BoxFuture; +use pingora::http::StatusCode; use serde::de::Deserialize; use serde::ser::Serialize; +use std::fmt::Debug; /* * Each type in this crate has a specific purpose and may be updated as requirements evolve. @@ -41,7 +41,9 @@ use serde::ser::Serialize; /// async move { h.handle(ctx).await }.boxed() /// }); /// ``` -pub type APIHandler = Box Fn(&'a T, &'a mut Layer8Context) -> BoxFuture<'a, APIHandlerResponse> + Send + Sync>; +pub type APIHandler = Box< + dyn for<'a> Fn(&'a T, &'a mut Layer8Context) -> BoxFuture<'a, APIHandlerResponse> + Send + Sync, +>; /// `APIHandlerResponse` contains information returned by handlers and can be /// shared across handlers during request processing. @@ -68,13 +70,15 @@ pub trait ResponseBodyTrait: Serialize + for<'de> Deserialize<'de> + Debug { serde_json::to_vec(self).unwrap() } - fn from_bytes(bytes: Vec) -> Result, serde_json::Error> { - serde_json::from_slice(&bytes) + fn from_bytes(bytes: &[u8]) -> Result, serde_json::Error> { + serde_json::from_slice(bytes) } /// Override this method to handle error serialization if your handler implements /// the `DefaultHandler` trait. - fn from_json_err(_err: serde_json::Error) -> Option {None} + fn from_json_err(_err: serde_json::Error) -> Option { + None + } } /// `RequestBodyTrait` provides a default method to deserialize the request body bytes @@ -89,8 +93,8 @@ pub trait RequestBodyTrait: Serialize + for<'de> Deserialize<'de> + Debug { serde_json::to_vec(self).unwrap() } - fn from_bytes(bytes: Vec) -> Result, serde_json::Error> { - serde_json::from_slice(&bytes) + fn from_bytes(bytes: &[u8]) -> Result, serde_json::Error> { + serde_json::from_slice(bytes) } } @@ -107,11 +111,12 @@ pub trait RequestBodyTrait: Serialize + for<'de> Deserialize<'de> + Debug { /// If deserialization fails, it returns no body, an error response of type `E: impl /// ResponseBodyTrait` (constructed from the JSON error), and a 400 Bad Request status. pub trait DefaultHandlerTrait { - fn parse_request_body(data: &Vec) -> Result> - { - match T::from_bytes(data.clone()) { + fn parse_request_body( + data: &[u8], + ) -> Result> { + match T::from_bytes(data) { Ok(body) => Ok(*body), - Err(e) => Err(E::from_json_err(e)) + Err(e) => Err(E::from_json_err(e)), } } -} \ No newline at end of file +} diff --git a/pingora-router/src/lib.rs b/pingora-router/src/lib.rs index 7b98195..8bc7f7e 100644 --- a/pingora-router/src/lib.rs +++ b/pingora-router/src/lib.rs @@ -1,4 +1,4 @@ pub mod ctx; pub mod handler; -mod utils; pub mod router; +mod utils; diff --git a/pingora-router/src/router.rs b/pingora-router/src/router.rs index 36a99fb..be41e28 100644 --- a/pingora-router/src/router.rs +++ b/pingora-router/src/router.rs @@ -1,7 +1,7 @@ -use std::collections::HashMap; -use pingora::http::{Method, StatusCode}; use crate::ctx::{Layer8Context, Layer8ContextTrait}; use crate::handler::{APIHandler, APIHandlerResponse}; +use pingora::http::{Method, StatusCode}; +use std::collections::HashMap; /// `Router` is a generic struct that manages HTTP route registration and handler dispatching. /// @@ -65,12 +65,24 @@ impl Router { } } - fn get_handlers(&self, method: &Method, path: &str) -> Option<&Box<[APIHandler]>> { + fn get_handlers(&self, method: &Method, path: &str) -> Option<&[APIHandler]> { match *method { - Method::POST => self.posts.get(path), - Method::GET => self.gets.get(path), - Method::PUT => self.puts.get(path), - Method::DELETE => self.deletes.get(path), + Method::POST => self + .posts + .get(path) + .map(|box_handlers| box_handlers.as_ref()), + Method::GET => self + .gets + .get(path) + .map(|box_handlers| box_handlers.as_ref()), + Method::PUT => self + .puts + .get(path) + .map(|box_handlers| box_handlers.as_ref()), + Method::DELETE => self + .deletes + .get(path) + .map(|box_handlers| box_handlers.as_ref()), _ => return None, } } @@ -122,5 +134,3 @@ impl Router { self.deletes.insert(self.get_base_path(&path), handlers); } } - - diff --git a/pingora-router/src/utils.rs b/pingora-router/src/utils.rs index 5c7f346..ee5bc61 100644 --- a/pingora-router/src/utils.rs +++ b/pingora-router/src/utils.rs @@ -4,16 +4,14 @@ pub(crate) async fn get_request_body(session: &mut Session) -> pingora::Result { - match option { - Some(chunk) => body.extend_from_slice(&chunk), - None => break, - } - } + Ok(option) => match option { + Some(chunk) => body.extend_from_slice(&chunk), + None => break, + }, Err(err) => { return Err(err); } } } Ok(body) -} \ No newline at end of file +} diff --git a/reverse-proxy/Cargo.toml b/reverse-proxy/Cargo.toml index f971e24..861b5b6 100644 --- a/reverse-proxy/Cargo.toml +++ b/reverse-proxy/Cargo.toml @@ -5,25 +5,30 @@ edition = "2024" [dependencies] async-trait = "0.1" +bincode = "2.0.1" +boring = "4.17.0" bytes = "1.10.1" +chrono = "0.4.40" +config = "0.15.11" +dotenv = "0.15.0" +env_logger = "0.11.7" +envy = "0.4.2" +futures = "0.3.31" +hex = "0.4.3" +ntor = { git = "https://github.com/globe-and-citizen/ntor.git", branch = "fix/use-bincode-instead-of-json" } +once_cell = "1.21.3" +pingora = { version = "0.5.0", features = ["lb", "boringssl"] } +pingora-router = { path = "../pingora-router", version = "0.1.0" } +reqwest = { version = "0.11", default-features = false, features = [ + "json", + "rustls-tls", +] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" serde_yaml = "0.8.26" -chrono = "0.4.40" -reqwest = { version="0.11", default-features=false, features=["json", "rustls-tls"] } -tokio-rustls = "0.26.2" tokio = "1.44.2" -pingora-router = { path = "../pingora-router", version = "0.1.0" } -pingora = { version = "0.5.0", features = ["lb", "boringssl"] } -futures = "0.3.31" -boring = "4.17.0" -once_cell = "1.21.3" -dotenv = "0.15.0" -ntor = { git = "https://github.com/globe-and-citizen/ntor.git", tag = "0.1.1" } -config = "0.15.11" +tokio-rustls = "0.26.2" toml = "0.8.23" -uuid = { version = "1.16.0", features = ["v4"] } utils = { path = "../utils", version = "0.1.0" } -envy = "0.4.2" -hex = "0.4.3" tracing = "0.1.41" +uuid = { version = "1.16.0", features = ["v4"] } diff --git a/reverse-proxy/src/config.rs b/reverse-proxy/src/config.rs index 26a0b28..d909251 100644 --- a/reverse-proxy/src/config.rs +++ b/reverse-proxy/src/config.rs @@ -1,5 +1,5 @@ -use serde::Deserialize; use crate::tls_conf::TlsConfig; +use serde::Deserialize; #[derive(Debug, Deserialize, Clone)] pub struct RPConfig { @@ -10,7 +10,7 @@ pub struct RPConfig { #[serde(flatten)] pub tls: TlsConfig, #[serde(flatten)] - pub handler: HandlerConfig + pub handler: HandlerConfig, } #[derive(Debug, Deserialize, Clone)] @@ -28,7 +28,7 @@ pub(super) struct LogConfig { pub(super) struct ServerConfig { pub listen_address: String, #[serde(deserialize_with = "utils::deserializer::string_to_number")] - pub listen_port: u16 + pub listen_port: u16, } #[derive(Debug, Deserialize, Clone)] diff --git a/reverse-proxy/src/handler/common/consts.rs b/reverse-proxy/src/handler/common/consts.rs index e85559a..cfef011 100644 --- a/reverse-proxy/src/handler/common/consts.rs +++ b/reverse-proxy/src/handler/common/consts.rs @@ -16,4 +16,4 @@ impl LogTypes { #[allow(dead_code)] pub const HEALTHCHECK: &'static str = "HEALTHCHECK"; pub const TLS_HANDSHAKE: &'static str = "TLS_HANDSHAKE"; -} \ No newline at end of file +} diff --git a/reverse-proxy/src/handler/common/handler.rs b/reverse-proxy/src/handler/common/handler.rs index b2cd99c..2ff6724 100644 --- a/reverse-proxy/src/handler/common/handler.rs +++ b/reverse-proxy/src/handler/common/handler.rs @@ -1,4 +1,4 @@ /// Struct containing only associated methods (no instance methods or fields). /// The contents are quite drafting, but the idea is to handle common operations #[allow(dead_code)] -pub struct CommonHandler {} \ No newline at end of file +pub struct CommonHandler {} diff --git a/reverse-proxy/src/handler/common/mod.rs b/reverse-proxy/src/handler/common/mod.rs index a6ef363..34cd7a5 100644 --- a/reverse-proxy/src/handler/common/mod.rs +++ b/reverse-proxy/src/handler/common/mod.rs @@ -1,3 +1,3 @@ pub(crate) mod consts; -pub mod types; pub mod handler; +pub mod types; diff --git a/reverse-proxy/src/handler/common/types.rs b/reverse-proxy/src/handler/common/types.rs index 93b7cd4..251b702 100644 --- a/reverse-proxy/src/handler/common/types.rs +++ b/reverse-proxy/src/handler/common/types.rs @@ -1,17 +1,17 @@ -use std::fmt::Debug; -use serde::{Deserialize, Serialize}; use pingora_router::handler::ResponseBodyTrait; +use serde::{Deserialize, Serialize}; use serde_json::Error; +use std::fmt::Debug; #[derive(Serialize, Deserialize, Debug)] pub struct ErrorResponse { - pub error: String + pub error: String, } impl ResponseBodyTrait for ErrorResponse { fn from_json_err(err: Error) -> Option { Some(ErrorResponse { - error: err.to_string() + error: err.to_string(), }) } -} \ No newline at end of file +} diff --git a/reverse-proxy/src/handler/healthcheck/mod.rs b/reverse-proxy/src/handler/healthcheck/mod.rs index ee98941..a3d5d35 100644 --- a/reverse-proxy/src/handler/healthcheck/mod.rs +++ b/reverse-proxy/src/handler/healthcheck/mod.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use pingora_router::handler::ResponseBodyTrait; +use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug)] pub struct RpHealthcheckSuccess { @@ -13,4 +13,4 @@ pub struct RpHealthcheckError { pub(crate) rp_healthcheck_error: String, } -impl ResponseBodyTrait for RpHealthcheckError {} \ No newline at end of file +impl ResponseBodyTrait for RpHealthcheckError {} diff --git a/reverse-proxy/src/handler/init_tunnel/handler.rs b/reverse-proxy/src/handler/init_tunnel/handler.rs index 4811e9c..596c2c1 100644 --- a/reverse-proxy/src/handler/init_tunnel/handler.rs +++ b/reverse-proxy/src/handler/init_tunnel/handler.rs @@ -1,11 +1,12 @@ use pingora::http::StatusCode; use pingora_router::ctx::{Layer8Context, Layer8ContextTrait}; use pingora_router::handler::{APIHandlerResponse, DefaultHandlerTrait, ResponseBodyTrait}; + use crate::handler::common::types::ErrorResponse; -use crate::handler::init_tunnel::{InitEncryptedTunnelRequest}; +use crate::handler::init_tunnel::InitEncryptedTunnelRequest; /// Struct containing only associated methods (no instance methods or fields) -pub(crate) struct InitTunnelHandler {} +pub(crate) struct InitTunnelHandler; impl DefaultHandlerTrait for InitTunnelHandler {} @@ -13,18 +14,17 @@ impl InitTunnelHandler { pub(crate) async fn validate_request_body( ctx: &mut Layer8Context, _backend_url: String, - ) -> Result - { + ) -> Result { return match InitTunnelHandler::parse_request_body::< InitEncryptedTunnelRequest, - ErrorResponse + ErrorResponse, >(&ctx.get_request_body()) { Ok(res) => Ok(res), Err(err) => { let body = match err { None => None, - Some(err_response) => Some(err_response.to_bytes()) + Some(err_response) => Some(err_response.to_bytes()), }; Err(APIHandlerResponse { diff --git a/reverse-proxy/src/handler/init_tunnel/mod.rs b/reverse-proxy/src/handler/init_tunnel/mod.rs index 486ddf9..eeb3e9e 100644 --- a/reverse-proxy/src/handler/init_tunnel/mod.rs +++ b/reverse-proxy/src/handler/init_tunnel/mod.rs @@ -1,7 +1,7 @@ pub(crate) mod handler; -use serde::{Deserialize, Serialize}; use pingora_router::handler::{RequestBodyTrait, ResponseBodyTrait}; +use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug)] pub struct InitEncryptedTunnelRequest { @@ -10,13 +10,6 @@ pub struct InitEncryptedTunnelRequest { impl RequestBodyTrait for InitEncryptedTunnelRequest {} -#[derive(Serialize, Deserialize, Debug)] -pub struct InitTunnelRequestToBackend { - pub success: bool, -} - -impl RequestBodyTrait for InitTunnelRequestToBackend {} - #[derive(Serialize, Deserialize, Debug)] pub struct InitEncryptedTunnelResponse { pub public_key: Vec, @@ -26,7 +19,7 @@ pub struct InitEncryptedTunnelResponse { pub int_rp_jwt: String, #[serde(rename = "jwt2")] - pub fp_rp_jwt: String + pub fp_rp_jwt: String, } impl ResponseBodyTrait for InitEncryptedTunnelResponse {} diff --git a/reverse-proxy/src/handler/mod.rs b/reverse-proxy/src/handler/mod.rs index 9a4dec3..4b8d867 100644 --- a/reverse-proxy/src/handler/mod.rs +++ b/reverse-proxy/src/handler/mod.rs @@ -1,24 +1,26 @@ use std::collections::HashMap; use std::sync::{Mutex, MutexGuard}; + use ntor::common::{InitSessionMessage, NTorParty}; use ntor::server::NTorServer; use pingora::http::StatusCode; -use tracing::{debug, info}; use pingora_router::ctx::{Layer8Context, Layer8ContextTrait}; use pingora_router::handler::{APIHandlerResponse, ResponseBodyTrait}; -use init_tunnel::handler::InitTunnelHandler; use proxy::handler::ProxyHandler; -use init_tunnel::InitEncryptedTunnelResponse; -use utils::{new_uuid}; +use tracing::info; use utils::jwt::JWTClaims; -use crate::config::{HandlerConfig, RPConfig}; -use crate::handler::common::consts::LogTypes; -use crate::handler::healthcheck::{RpHealthcheckError, RpHealthcheckSuccess}; +use utils::new_uuid; pub(crate) mod common; +mod healthcheck; mod init_tunnel; mod proxy; -mod healthcheck; + +use crate::config::{HandlerConfig, RPConfig}; +use crate::handler::common::consts::LogTypes; +use crate::handler::healthcheck::{RpHealthcheckError, RpHealthcheckSuccess}; +use init_tunnel::InitEncryptedTunnelResponse; +use init_tunnel::handler::InitTunnelHandler; thread_local! { // @@ -33,7 +35,7 @@ pub struct ReverseHandler { impl ReverseHandler { pub fn new(config: RPConfig) -> Self { - let ntor_secret = config.handler.ntor_static_secret.clone(); + let ntor_secret = config.handler.ntor_static_secret; let jwt_secret = config.handler.jwt_virtual_connection_secret.clone(); ReverseHandler { @@ -51,12 +53,10 @@ impl ReverseHandler { return match shared_secret { Some(secret) => Ok(secret.clone()), - None => { - Err(APIHandlerResponse { - status: StatusCode::UNAUTHORIZED, - body: Some("Invalid or expired nTor session ID".as_bytes().to_vec()), - }) - } + None => Err(APIHandlerResponse { + status: StatusCode::UNAUTHORIZED, + body: Some("Invalid or expired nTor session ID".as_bytes().to_vec()), + }), }; } @@ -64,13 +64,13 @@ impl ReverseHandler { let correlation_id = ctx.get_correlation_id(); // validate request body - let request_body = match InitTunnelHandler::validate_request_body( - ctx, - self.config.backend_url.clone(), - ).await { - Ok(res) => res, - Err(res) => return res - }; + let request_body = + match InitTunnelHandler::validate_request_body(ctx, self.config.backend_url.clone()) + .await + { + Ok(res) => res, + Err(res) => return res, + }; // todo I think there are prettier ways to use nTor since we are free to modify the nTor crate, but I'm lazy let mut ntor_server = NTorServer::new_with_secret( @@ -105,8 +105,8 @@ impl ReverseHandler { }; let response = InitEncryptedTunnelResponse { - public_key: init_session_response.public_key(), - t_b_hash: init_session_response.t_b_hash(), + public_key: init_session_response.public_key().to_vec(), + t_b_hash: init_session_response.t_b_hash().to_vec(), int_rp_jwt, fp_rp_jwt, }; @@ -118,9 +118,13 @@ impl ReverseHandler { "Save new nTor session: {}", ntor_session_id ); + NTOR_SHARED_SECRETS.with(|memory| { let mut guard: MutexGuard>> = memory.lock().unwrap(); - guard.insert(ntor_session_id, ntor_server.get_shared_secret().unwrap_or_default()); + guard.insert( + ntor_session_id, + ntor_server.get_shared_secret().unwrap_or_default().to_vec(), + ); }); APIHandlerResponse { @@ -169,32 +173,40 @@ impl ReverseHandler { ctx, self.config.backend_url.clone(), wrapped_request, - ).await { + ) + .await + { Ok(res) => res, Err(res) => return res, }; - return match ProxyHandler::encrypt_response_body( + let encrypted_message = ProxyHandler::encrypt_response_body( wrapped_response, self.config.ntor_server_id.clone(), shared_secret, - ) { + ); + + match encrypted_message { Ok(encrypted_message) => { + let data = bincode::encode_to_vec(&encrypted_message, bincode::config::standard()) + .expect("this struct is bincode serializable"); + APIHandlerResponse { status: StatusCode::OK, - body: Some(encrypted_message.to_bytes()), + body: Some(data), } } - Err(res) => res - }; + Err(res) => res, + } } pub async fn handle_healthcheck(&self, ctx: &mut Layer8Context) -> APIHandlerResponse { if let Some(error) = ctx.param("error") { if error == "true" { let response_bytes = RpHealthcheckError { - rp_healthcheck_error: "this is placeholder for a custom error".to_string() - }.to_bytes(); + rp_healthcheck_error: "this is placeholder for a custom error".to_string(), + } + .to_bytes(); ctx.insert_response_header("x-rp-healthcheck-error", "response-header-error"); return APIHandlerResponse { @@ -206,7 +218,8 @@ impl ReverseHandler { let response_bytes = RpHealthcheckSuccess { rp_healthcheck_success: "this is placeholder for a custom body".to_string(), - }.to_bytes(); + } + .to_bytes(); ctx.insert_response_header("x-rp-healthcheck-success", "response-header-success"); @@ -215,4 +228,4 @@ impl ReverseHandler { body: Some(response_bytes), }; } -} \ No newline at end of file +} diff --git a/reverse-proxy/src/handler/proxy/handler.rs b/reverse-proxy/src/handler/proxy/handler.rs index 0b78a5a..e5049a9 100644 --- a/reverse-proxy/src/handler/proxy/handler.rs +++ b/reverse-proxy/src/handler/proxy/handler.rs @@ -1,45 +1,56 @@ -use pingora_router::ctx::{Layer8Context, Layer8ContextTrait}; -use reqwest::header::HeaderMap; -use pingora_router::handler::{APIHandlerResponse, DefaultHandlerTrait, ResponseBodyTrait}; -use ntor::common::NTorParty; +use ntor::common::{EncryptedMessage, NTorParty}; use ntor::server::NTorServer; -use reqwest::Client; use pingora::http::StatusCode; +use pingora_router::ctx::{Layer8Context, Layer8ContextTrait}; +use pingora_router::handler::{APIHandlerResponse, ResponseBodyTrait}; +use reqwest::Client; +use reqwest::header::HeaderMap; use tracing::{debug, error, info}; use utils::bytes_to_json; use utils::jwt::JWTClaims; + use crate::handler::common::consts::{HeaderKeys, LogTypes}; use crate::handler::common::types::ErrorResponse; -use crate::handler::proxy::{EncryptedMessage, L8ResponseObject, L8RequestObject}; +use crate::handler::proxy::{L8RequestObject, L8ResponseObject}; /// Struct containing only associated methods (no instance methods or fields) -pub struct ProxyHandler {} - -impl DefaultHandlerTrait for ProxyHandler {} +pub struct ProxyHandler; impl ProxyHandler { + fn parse_request_body(data: &[u8]) -> Result { + match bincode::decode_from_slice(data, bincode::config::standard()) { + Ok((body, _)) => Ok(body), + Err(e) => Err(e.to_string()), + } + } fn validate_jwt_token( ctx: &mut Layer8Context, header_key: &str, - jwt_secret: &Vec + jwt_secret: &[u8], ) -> Result { match ctx.get_request_header().get(header_key) { None => { return Err(APIHandlerResponse { status: StatusCode::BAD_REQUEST, - body: Some(ErrorResponse { - error: format!("Missing {} header", header_key.to_string()), - }.to_bytes()), + body: Some( + ErrorResponse { + error: format!("Missing {} header", header_key), + } + .to_bytes(), + ), }); - }, + } Some(token) => { if token.is_empty() { return Err(APIHandlerResponse { status: StatusCode::BAD_REQUEST, - body: Some(ErrorResponse { - error: format!("Empty {} header", header_key.to_string()), - }.to_bytes()), + body: Some( + ErrorResponse { + error: format!("Empty {} header", header_key), + } + .to_bytes(), + ), }); } @@ -48,19 +59,22 @@ impl ProxyHandler { Ok(data) => Ok(data.claims), Err(err) => { error!( - correlation_id=ctx.get_correlation_id(), - log_type=LogTypes::HANDLE_PROXY_REQUEST, + correlation_id = ctx.get_correlation_id(), + log_type = LogTypes::HANDLE_PROXY_REQUEST, "Error verifying {} token: {:?}", header_key, err ); Err(APIHandlerResponse { status: StatusCode::BAD_REQUEST, - body: Some(ErrorResponse { - error: err.to_string(), - }.to_bytes()), + body: Some( + ErrorResponse { + error: err.to_string(), + } + .to_bytes(), + ), }) - }, + } } } } @@ -70,14 +84,13 @@ impl ProxyHandler { pub(crate) fn validate_request_headers( ctx: &mut Layer8Context, jwt_secret: &Vec, - ) -> Result - { + ) -> Result { // verify fp_rp_jwt header match ProxyHandler::validate_jwt_token(ctx, HeaderKeys::FP_RP_JWT, jwt_secret) { Ok(_claims) => { // todo!() nothing to validate at the moment } - Err(err) => return Err(err) + Err(err) => return Err(err), } return match ProxyHandler::validate_jwt_token(ctx, HeaderKeys::INT_RP_JWT_KEY, jwt_secret) { @@ -87,43 +100,40 @@ impl ProxyHandler { Some(ntor_session_id) => Ok(ntor_session_id), None => Err(APIHandlerResponse { status: StatusCode::BAD_REQUEST, - body: Some(ErrorResponse { - error: "Missing ntor_session_id in JWT claims".to_string(), - }.to_bytes()), + body: Some( + ErrorResponse { + error: "Missing ntor_session_id in JWT claims".to_string(), + } + .to_bytes(), + ), }), } } - Err(err) => return Err(err) + Err(err) => return Err(err), }; } pub(crate) fn validate_request_body( - ctx: &mut Layer8Context - ) -> Result - { - let correlation_id = ctx.get_correlation_id(); - - match ProxyHandler::parse_request_body::< - EncryptedMessage, - ErrorResponse - >(&ctx.get_request_body()) { - Ok(res) => Ok(res), + ctx: &mut Layer8Context, + ) -> Result { + match ProxyHandler::parse_request_body(&ctx.get_request_body()) { + Ok(data) => Ok(data), Err(err) => { - let body = match err { - None => None, - Some(err_response) => { - error!( - %correlation_id, - log_type=LogTypes::HANDLE_PROXY_REQUEST, - "Error parsing request body: {}", - err_response.error - ); - Some(err_response.to_bytes()) - } - }; + let correlation_id = ctx.get_correlation_id(); + error!( + %correlation_id, + log_type=LogTypes::HANDLE_PROXY_REQUEST, + "Error parsing request body: {}", + err + ); + Err(APIHandlerResponse { status: StatusCode::BAD_REQUEST, - body, + body: Some( + format!("Failed to parse request body: {}", err) + .as_bytes() + .to_vec(), + ), }) } } @@ -133,31 +143,29 @@ impl ProxyHandler { request_body: EncryptedMessage, ntor_server_id: String, shared_secret: Vec, - ) -> Result - { + ) -> Result { let mut ntor_server = NTorServer::new(ntor_server_id); - ntor_server.set_shared_secret(shared_secret.clone()); + ntor_server.set_shared_secret(shared_secret); // Decrypt the request body using nTor shared secret - let decrypted_data = ntor_server.decrypt(ntor::common::EncryptedMessage { - nonce: <[u8; 12]>::try_from(request_body.nonce).unwrap_or_default(), - data: request_body.data, - }).map_err(|err| { + let decrypted_data = ntor_server.decrypt(request_body).map_err(|err| { return APIHandlerResponse { status: StatusCode::BAD_REQUEST, body: Some(format!("Decryption failed: {}", err).as_bytes().to_vec()), }; })?; - // let decrypted_data = request_body.data; // parse decrypted data into WrappedUserRequest - let wrapped_request: L8RequestObject = bytes_to_json(decrypted_data) - .map_err(|err| { - return APIHandlerResponse { - status: StatusCode::BAD_REQUEST, - body: Some(format!("Failed to parse request body: {}", err).as_bytes().to_vec()), - }; - })?; + let wrapped_request: L8RequestObject = bytes_to_json(&decrypted_data).map_err(|err| { + return APIHandlerResponse { + status: StatusCode::BAD_REQUEST, + body: Some( + format!("Failed to parse request body: {}", err) + .as_bytes() + .to_vec(), + ), + }; + })?; Ok(wrapped_request) } @@ -165,9 +173,8 @@ impl ProxyHandler { pub(crate) async fn rebuild_user_request( ctx: &Layer8Context, backend_url: String, - wrapped_request: L8RequestObject - ) -> Result - { + wrapped_request: L8RequestObject, + ) -> Result { let correlation_id = ctx.get_correlation_id(); let header_map = utils::hashmap_to_headermap(&wrapped_request.headers) .unwrap_or_else(|_| HeaderMap::new()); @@ -188,10 +195,11 @@ impl ProxyHandler { "Send reconstructed request to origin backend URL: {}", origin_url ); - let response = client.request( - wrapped_request.method.parse().unwrap_or_default(), - origin_url.as_str(), - ) + let response = client + .request( + wrapped_request.method.parse().unwrap_or_default(), + origin_url.as_str(), + ) .headers(header_map.clone()) .body(wrapped_request.body) .send() @@ -200,7 +208,8 @@ impl ProxyHandler { return match response { Ok(success_res) => { let status = success_res.status().as_u16(); - let status_text = success_res.status() + let status_text = success_res + .status() .canonical_reason() .unwrap_or("OK") .to_string(); @@ -236,7 +245,9 @@ impl ProxyHandler { "Error while building request to BE: {:?}", err ); - let status = err.status().unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR); + let status = err + .status() + .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR); let err_body = ErrorResponse { error: format!("Backend error: {}", status), }; @@ -253,24 +264,18 @@ impl ProxyHandler { response_body: L8ResponseObject, ntor_server_id: String, shared_secret: Vec, - ) -> Result - { + ) -> Result { let mut ntor_server = NTorServer::new(ntor_server_id); ntor_server.set_shared_secret(shared_secret); let data = response_body.to_bytes(); // Encrypt the response body using nTor shared secret - let encrypted_data = ntor_server.encrypt(data).map_err(|err| { + ntor_server.encrypt(&data).map_err(|err| { return APIHandlerResponse { status: StatusCode::INTERNAL_SERVER_ERROR, body: Some(format!("Encryption failed: {}", err).as_bytes().to_vec()), }; - })?; - - Ok(EncryptedMessage { - nonce: encrypted_data.nonce.to_vec(), - data: encrypted_data.data, }) } } diff --git a/reverse-proxy/src/handler/proxy/mod.rs b/reverse-proxy/src/handler/proxy/mod.rs index 0458ea0..f5b6c73 100644 --- a/reverse-proxy/src/handler/proxy/mod.rs +++ b/reverse-proxy/src/handler/proxy/mod.rs @@ -1,24 +1,15 @@ pub(crate) mod handler; -use std::collections::HashMap; use pingora_router::handler::{RequestBodyTrait, ResponseBodyTrait}; use serde::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize, Debug)] -pub struct EncryptedMessage { - pub nonce: Vec, - pub data: Vec, -} - -impl RequestBodyTrait for EncryptedMessage {} -impl ResponseBodyTrait for EncryptedMessage {} +use std::collections::HashMap; #[derive(Serialize, Deserialize, Debug)] pub struct L8RequestObject { pub method: String, pub uri: String, pub headers: HashMap, - pub body: Vec + pub body: Vec, } impl RequestBodyTrait for L8RequestObject {} @@ -31,7 +22,6 @@ pub struct L8ResponseObject { pub ok: bool, pub url: String, pub redirected: bool, - /* Other fields are ignored because reqwest does not support */ } diff --git a/reverse-proxy/src/main.rs b/reverse-proxy/src/main.rs index fe31a78..a03eab0 100644 --- a/reverse-proxy/src/main.rs +++ b/reverse-proxy/src/main.rs @@ -1,29 +1,31 @@ -mod handler; -mod proxy; -mod tls_conf; +use std::sync::Arc; -use crate::handler::ReverseHandler; use futures::FutureExt; use pingora::server::Server; use pingora::server::configuration::Opt; use pingora::{listeners::tls::TlsSettings, prelude::http_proxy_service}; use pingora_router::handler::APIHandler; use pingora_router::router::Router; -use std::sync::Arc; use tracing::{debug, error}; -use crate::config::RPConfig; -use crate::proxy::ReverseProxy; mod config; +mod handler; +mod proxy; +mod tls_conf; +use config::RPConfig; +use handler::ReverseHandler; +use proxy::ReverseProxy; fn load_config() -> RPConfig { // Load environment variables from .env file dotenv::dotenv().ok(); // Deserialize from env vars - let config: RPConfig = envy::from_env().map_err(|e| { - error!("Failed to load configuration: {}", e); - }).unwrap(); + let config: RPConfig = envy::from_env() + .map_err(|e| { + error!("Failed to load configuration: {}", e); + }) + .unwrap(); debug!(name: "RPConfig", value = ?config); config @@ -43,7 +45,8 @@ fn main() { let mut my_server = Server::new(Some(Opt { conf: std::env::var("SERVER_CONF").ok(), ..Default::default() - })).unwrap(); + })) + .unwrap(); my_server.bootstrap(); let handle_init_tunnel: APIHandler> = @@ -61,26 +64,22 @@ fn main() { router.post("/proxy".to_string(), Box::new([handle_proxy])); router.get("/healthcheck".to_string(), Box::new([handle_healthcheck])); - let mut my_proxy = http_proxy_service( - &my_server.configuration, - ReverseProxy::new(router), - ); + let mut my_proxy = http_proxy_service(&my_server.configuration, ReverseProxy::new(router)); if rp_config.tls.enable_tls { my_proxy.add_tls_with_settings( &format!( "{}:{}", - rp_config.server.listen_address, - rp_config.server.listen_port + rp_config.server.listen_address, rp_config.server.listen_port ), None, - TlsSettings::with_callbacks(Box::new(rp_config.tls)).expect("Cannot set TlsSettings callbacks") + TlsSettings::with_callbacks(Box::new(rp_config.tls)) + .expect("Cannot set TlsSettings callbacks"), ); } else { my_proxy.add_tcp(&format!( "{}:{}", - rp_config.server.listen_address, - rp_config.server.listen_port + rp_config.server.listen_address, rp_config.server.listen_port )); } diff --git a/reverse-proxy/src/proxy.rs b/reverse-proxy/src/proxy.rs index 00a137c..3f508e9 100644 --- a/reverse-proxy/src/proxy.rs +++ b/reverse-proxy/src/proxy.rs @@ -1,38 +1,41 @@ -use pingora::prelude::{HttpPeer, ProxyHttp}; -use pingora::proxy::Session; -use pingora::http::{ResponseHeader, StatusCode}; use async_trait::async_trait; use bytes::Bytes; -use tracing::{debug, info}; +use pingora::http::{ResponseHeader, StatusCode}; +use pingora::prelude::{HttpPeer, ProxyHttp}; +use pingora::proxy::Session; use pingora_router::ctx::{Layer8Context, Layer8ContextTrait}; use pingora_router::router::Router; +use tracing::info; + use crate::handler::common::consts::LogTypes; pub struct ReverseProxy { - router: Router + router: Router, } impl ReverseProxy { pub fn new(router: Router) -> Self { - ReverseProxy { - router - } + ReverseProxy { router } } async fn set_headers( session: &mut Session, ctx: &mut Layer8Context, - response_status: StatusCode + response_status: StatusCode, ) -> pingora::Result<()> { let mut header = ResponseHeader::build(response_status, None)?; let response_header = ctx.get_response_header().clone(); for (key, val) in response_header.iter() { - header.insert_header(key.clone(), val.clone()).unwrap_or_default(); - }; + header + .insert_header(key.clone(), val.clone()) + .unwrap_or_default(); + } // Common headers - header.insert_header("Content-Type", "application/json").unwrap_or_default(); + header + .insert_header("Content-Type", "application/json") + .unwrap_or_default(); header .insert_header("Access-Control-Allow-Origin", "*") .unwrap_or_default(); @@ -70,13 +73,16 @@ impl ProxyHttp for ReverseProxy { _session: &mut Session, _ctx: &mut Self::CTX, ) -> pingora::Result> { - let peer: Box = - Box::new(HttpPeer::new("", false, "".to_string())); + let peer: Box = Box::new(HttpPeer::new("", false, "".to_string())); Ok(peer) } /// Handle request/response data by creating a new request to BE and respond to FP - async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> pingora::Result + async fn request_filter( + &self, + session: &mut Session, + ctx: &mut Self::CTX, + ) -> pingora::Result where Self::CTX: Send + Sync, { @@ -112,7 +118,9 @@ impl ProxyHttp for ReverseProxy { ReverseProxy::::set_headers(session, ctx, handler_response.status).await?; // Write the response body to the session after setting headers - session.write_response_body(Some(Bytes::from(response_bytes)), true).await?; + session + .write_response_body(Some(Bytes::from(response_bytes)), true) + .await?; Ok(true) } diff --git a/reverse-proxy/src/tls_conf.rs b/reverse-proxy/src/tls_conf.rs index 64d0bf7..f9f6d35 100644 --- a/reverse-proxy/src/tls_conf.rs +++ b/reverse-proxy/src/tls_conf.rs @@ -1,3 +1,4 @@ +use crate::handler::common::consts::LogTypes; use boring::{ pkey::{PKey, Public}, ssl::{SslAlert, SslRef, SslVerifyError, SslVerifyMode}, @@ -5,7 +6,6 @@ use boring::{ use pingora::{listeners::TlsAccept, protocols::tls::TlsRef}; use serde::Deserialize; use tracing::{debug, error, info}; -use crate::handler::common::consts::LogTypes; #[derive(Debug, Deserialize, Clone)] pub struct TlsConfig { @@ -23,7 +23,7 @@ impl TlsAccept for TlsConfig { ssl.set_hostname("localhost") .inspect_err(|e| { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to set hostname: {}", e ); }) @@ -34,7 +34,7 @@ impl TlsAccept for TlsConfig { let key = PKey::private_key_from_pem(&self.key.clone().into_bytes()) .inspect_err(|e| { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to parse server private key: {}", e ); }) @@ -42,7 +42,7 @@ impl TlsAccept for TlsConfig { ssl.set_private_key(&key) .inspect_err(|e| { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to set server private key: {}", e ); }) @@ -54,7 +54,7 @@ impl TlsAccept for TlsConfig { let cert = boring::x509::X509::from_pem(&self.cert.clone().into_bytes()) .inspect_err(|e| { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to parse server certificate: {}", e ); }) @@ -63,7 +63,7 @@ impl TlsAccept for TlsConfig { ssl.set_certificate(&cert) .inspect_err(|e| { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to set server certificate: {}", e ); }) @@ -74,7 +74,7 @@ impl TlsAccept for TlsConfig { let ca_cert = boring::x509::X509::from_pem(&self.ca_cert.clone().into_bytes()) .inspect_err(|e| { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to parse CA certificate: {}", e ); }) @@ -87,10 +87,12 @@ impl TlsAccept for TlsConfig { } } +// Callback function type for verifying the client certificate +type VerifyCallback = + Box Result<(), SslVerifyError> + 'static + Sync + Send>; + impl TlsConfig { - fn verify_callback( - ca_cert_pub_key: PKey, - ) -> Box Result<(), SslVerifyError> + 'static + Sync + Send> { + fn verify_callback(ca_cert_pub_key: PKey) -> VerifyCallback { Box::new(move |ssl| -> Result<(), SslVerifyError> { Self::verify_client_file(&ca_cert_pub_key, ssl) }) @@ -102,7 +104,7 @@ impl TlsConfig { ) -> Result<(), SslVerifyError> { if ssl.verify_mode() != SslVerifyMode::PEER { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "SSL verify mode is not set to PEER, cannot verify client certificate" ); return Err(SslVerifyError::Invalid(SslAlert::INTERNAL_ERROR)); @@ -112,7 +114,7 @@ impl TlsConfig { Some(val) => val, None => { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Failed to get client certificate" ); return Err(SslVerifyError::Invalid(SslAlert::NO_CERTIFICATE)); @@ -123,15 +125,18 @@ impl TlsConfig { debug!("Debug Client certificate: {:?}", client_cert.subject_name()); // Verify the client certificate against the server's CA - if !client_cert.verify(&server_ca_public_key).unwrap_or_default() { + if !client_cert + .verify(&server_ca_public_key) + .unwrap_or_default() + { error!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Client certificate verification failed" ); return Err(SslVerifyError::Invalid(SslAlert::BAD_CERTIFICATE)); } info!( - log_type=LogTypes::TLS_HANDSHAKE, + log_type = LogTypes::TLS_HANDSHAKE, "Client certificate verification succeeded" ); diff --git a/utils/src/cert.rs b/utils/src/cert.rs index 822944f..7f0fbbf 100644 --- a/utils/src/cert.rs +++ b/utils/src/cert.rs @@ -12,4 +12,4 @@ pub fn extract_x509_pem(pem: String) -> Result, Box(deserializer: D) -> Result where D: serde::Deserializer<'de>, - T: Add + Copy + FromStr, ::Err: std::fmt::Display, + T: Add + Copy + FromStr, + ::Err: std::fmt::Display, { let s: String = Deserialize::deserialize(deserializer).map_err(|e| { serde::de::Error::custom(format!("Failed to deserialize string to number: {}", e)) @@ -22,7 +23,9 @@ where })?; let bytes = s.into_bytes(); if bytes.len() != 32 { - return Err(serde::de::Error::custom("Expected 32 bytes for nTor static secret")); + return Err(serde::de::Error::custom( + "Expected 32 bytes for nTor static secret", + )); } let mut array = [0u8; 32]; array.copy_from_slice(&bytes); @@ -51,4 +54,4 @@ where "false" | "0" => Ok(false), _ => Err(serde::de::Error::custom("Expected 'true' or 'false'")), } -} \ No newline at end of file +} diff --git a/utils/src/jwt.rs b/utils/src/jwt.rs index b771419..37400f4 100644 --- a/utils/src/jwt.rs +++ b/utils/src/jwt.rs @@ -1,5 +1,5 @@ +use jsonwebtoken::{DecodingKey, TokenData, Validation, errors::Error as JwtError}; use serde::{Deserialize, Serialize}; -use jsonwebtoken::{DecodingKey, Validation, errors::Error as JwtError, TokenData}; /// JWT (JSON Web Token) claims structure. /// @@ -20,7 +20,6 @@ use jsonwebtoken::{DecodingKey, Validation, errors::Error as JwtError, TokenData #[derive(Serialize, Deserialize, Debug, Default)] pub struct JWTClaims { /* Registered claims */ - /// The "iss" (issuer) claim identifies the principal that issued the /// JWT. The processing of this claim is generally application specific. /// The "iss" value is a case-sensitive string containing a StringOrURI @@ -90,19 +89,23 @@ pub struct JWTClaims { pub jti: Option, /* Custom claims */ - /// This claim is required in JWT token `int_fp_jwt` between Interceptor and ForwardProxy. /// Used in ForwardProxy to identify the ReverseProxy server. - #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "upstream", deserialize = "upstream"))] + #[serde( + skip_serializing_if = "Option::is_none", + rename(serialize = "upstream", deserialize = "upstream") + )] pub rp_host: Option, /// This claim is required in JWT token `int_rp_jwt` between Interceptor and ReverseProxy. - #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "sid", deserialize = "sid"))] + #[serde( + skip_serializing_if = "Option::is_none", + rename(serialize = "sid", deserialize = "sid") + )] pub ntor_session_id: Option, /// The `uuid` claim is used to uniquely identify the token and help prevent race conditions. pub uuid: Option, - // Additional custom claims can be added here as needed. } @@ -115,7 +118,7 @@ impl JWTClaims { let now = chrono::Utc::now(); let expiration = now + chrono::Duration::hours(hours); Some(expiration.timestamp()) - }, + } None => None, }; @@ -150,15 +153,14 @@ pub fn create_jwt_token(claims: JWTClaims, jwt_secret: &[u8]) -> String { &jsonwebtoken::Header::default(), &claims, &jsonwebtoken::EncodingKey::from_secret(&jwt_secret), - ).unwrap() + ) + .unwrap() } -pub fn verify_jwt_token(token: &str, jwt_secret: &Vec) -> Result, JwtError> { +pub fn verify_jwt_token(token: &str, jwt_secret: &[u8]) -> Result, JwtError> { jsonwebtoken::decode::( token, - &DecodingKey::from_secret(jwt_secret.as_slice()), + &DecodingKey::from_secret(jwt_secret), &Validation::default(), ) } - - diff --git a/utils/src/lib.rs b/utils/src/lib.rs index ecda098..3534cfa 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,18 +1,17 @@ -pub mod jwt; -pub mod cert; -pub mod deserializer; -pub mod log; - -use url::Url; - use std::collections::HashMap; + use base64::Engine; use base64::engine::general_purpose; -use uuid::Uuid; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; - use serde::{Deserialize, Serialize}; use tracing::error; +use url::Url; +use uuid::Uuid; + +pub mod cert; +pub mod deserializer; +pub mod jwt; +pub mod log; pub fn to_reqwest_header(map: HashMap) -> HeaderMap { let mut header_map = HeaderMap::new(); @@ -47,14 +46,14 @@ pub fn string_to_array32(s: String) -> Option<[u8; 32]> { } } -pub fn bytes_to_json(bytes: Vec) -> Result +pub fn bytes_to_json(bytes: &[u8]) -> Result where T: Serialize + for<'de> Deserialize<'de>, { - serde_json::from_slice::(&bytes) + serde_json::from_slice::(bytes) } -pub fn bytes_to_string(bytes: &Vec) -> String { +pub fn bytes_to_string(bytes: &[u8]) -> String { String::from_utf8_lossy(bytes).to_string() } @@ -72,13 +71,9 @@ pub fn string_to_headermap(s: &str) -> Result String { - let pairs: Vec<(String, String)> = headers.iter() - .map( - |(k, v)| ( - k.to_string(), - v.to_str().unwrap_or("").to_string() - ) - ) + let pairs: Vec<(String, String)> = headers + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) .collect(); serde_json::to_string(&pairs).unwrap() } @@ -93,9 +88,8 @@ fn headervalue_to_json(val: &HeaderValue) -> serde_json::Value { // serde_json::Value to http::header::value::HeaderValue fn json_to_headervalue( - val: &serde_json::Value -) -> Result -{ + val: &serde_json::Value, +) -> Result { match val { serde_json::Value::String(s) => HeaderValue::from_str(s), _ => HeaderValue::from_str(&val.to_string()), @@ -103,9 +97,8 @@ fn json_to_headervalue( } pub fn hashmap_to_headermap( - map: &HashMap -) -> Result> -{ + map: &HashMap, +) -> Result> { let mut headers = HeaderMap::new(); for (k, v) in map { let name = HeaderName::from_bytes(k.as_bytes())?; @@ -136,9 +129,10 @@ pub fn validate_url(url: &str) -> Option { } pub fn get_socket_addrs(url: &Url) -> String { - url.socket_addrs(|| None).unwrap_or_default() + url.socket_addrs(|| None) + .unwrap_or_default() .iter() .map(|addr| addr.to_string()) .collect::>() .join(",") -} \ No newline at end of file +} diff --git a/utils/src/log.rs b/utils/src/log.rs index 94ec078..306858d 100644 --- a/utils/src/log.rs +++ b/utils/src/log.rs @@ -13,15 +13,17 @@ pub fn init_logger( let level_filter = to_level_filter(level); // Dynamic writer - let (writer, guard): (BoxMakeWriter, Option) = - if log_folder == "console" { - let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout()); - (BoxMakeWriter::new(non_blocking), Some(guard)) - } else { - let file_appender = rolling::daily(log_folder, log_file); - let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); - (BoxMakeWriter::new(non_blocking), Some(guard)) - }; + let (writer, guard): ( + BoxMakeWriter, + Option, + ) = if log_folder == "console" { + let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout()); + (BoxMakeWriter::new(non_blocking), Some(guard)) + } else { + let file_appender = rolling::daily(log_folder, log_file); + let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); + (BoxMakeWriter::new(non_blocking), Some(guard)) + }; // Structured JSON logger let builder = fmt::Subscriber::builder()