Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 88 additions & 92 deletions forward-proxy/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,64 +105,64 @@ impl ForwardHandler {
);

ctx.insert_response_header("Connection", "close"); // Ensure connection closes???

Err(APIHandlerResponse {
return Err(APIHandlerResponse {
status: StatusCode::BAD_REQUEST,
body: Some(response_body.to_bytes()),
})
} else {
#[derive(Deserialize, Debug)]
struct AuthServerResponse {
pub x509_certificate: String,
pub client_id: String,
body: Some(
serde_json::to_vec(&response_body).expect("this struct is json serializable"),
),
});
}
#[derive(Deserialize, Debug)]
struct AuthServerResponse {
pub x509_certificate: String,
pub client_id: String,
}

let auth_res: AuthServerResponse = res.json().await.map_err(|err| {
error!(
%correlation_id,
log_type=LogTypes::AUTHENTICATION_SERVER,
"Failed to parse authentication server response: {:?}",
err
);
APIHandlerResponse {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: None,
}
})?;

// save `client_id` to ctx for later use
ctx.set(
consts::CtxKeys::BACKEND_AUTH_CLIENT_ID.to_string(),
auth_res.client_id.clone(),
);

let auth_res: AuthServerResponse = res.json().await.map_err(|err| {
let pub_key =
utils::cert::extract_x509_pem(auth_res.x509_certificate.clone()).map_err(|e| {
error!(
%correlation_id,
log_type=LogTypes::AUTHENTICATION_SERVER,
"Failed to parse authentication server response: {:?}",
err
"Failed to parse x509 certificate: {:?}",
e
);
APIHandlerResponse {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: None,
}
})?;

// save `client_id` to ctx for later use
ctx.set(
consts::CtxKeys::BACKEND_AUTH_CLIENT_ID.to_string(),
auth_res.client_id.clone(),
);

let pub_key = utils::cert::extract_x509_pem(auth_res.x509_certificate.clone())
.map_err(|e| {
error!(
%correlation_id,
log_type=LogTypes::AUTHENTICATION_SERVER,
"Failed to parse x509 certificate: {:?}",
e
);
APIHandlerResponse {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: None,
}
})?;

debug!(%correlation_id, "AuthenticationServer response: {:?}", auth_res);
info!(
%correlation_id,
log_type=LogTypes::AUTHENTICATION_SERVER,
"Obtained ntor credentials for backend_url: {}",
backend_url
);
debug!(%correlation_id, "AuthenticationServer response: {:?}", auth_res);
info!(
%correlation_id,
log_type=LogTypes::AUTHENTICATION_SERVER,
"Obtained ntor credentials for backend_url: {}",
backend_url
);

Ok(NTorServerCertificate {
server_id: backend_url, // todo I still prefer taking the server_id value from certificate's subject
public_key: pub_key,
})
}
Ok(NTorServerCertificate {
server_id: backend_url, // todo I still prefer taking the server_id value from certificate's subject
public_key: pub_key,
})
}

/// Verify `int_fp_jwt` and return `fp_rp_jwt`
Expand Down Expand Up @@ -237,67 +237,63 @@ impl ForwardHandler {

pub fn handle_init_tunnel_response(&self, ctx: &mut Layer8Context) -> APIHandlerResponse {
let ntor_server_id = ctx
.get(&consts::CtxKeys::NTOR_SERVER_ID)
.get(consts::CtxKeys::NTOR_SERVER_ID)
.unwrap_or(&"".to_string())
.clone();
let ntor_static_public_key = hex::decode(
ctx.get(&consts::CtxKeys::NTOR_STATIC_PUBLIC_KEY)
ctx.get(consts::CtxKeys::NTOR_STATIC_PUBLIC_KEY)
.unwrap_or(&"".to_string()),
)
.unwrap_or_default();

let response_body = ctx.get_response_body();

return match utils::bytes_to_json::<InitTunnelResponseFromRP>(response_body) {
Err(e) => {
error!(
correlation_id = ctx.get_correlation_id(),
log_type = LogTypes::HANDLE_UPSTREAM_RESPONSE,
"Error parsing RP response: {:?}",
e
);
APIHandlerResponse {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: None,
let res_from_rp =
match serde_json::from_slice::<InitTunnelResponseFromRP>(ctx.get_response_body()) {
Ok(val) => val,
Err(e) => {
error!(
correlation_id = ctx.get_correlation_id(),
log_type = LogTypes::HANDLE_UPSTREAM_RESPONSE,
"Error parsing RP response: {:?}",
e
);
return APIHandlerResponse {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: None,
};
}
}
Ok(res_from_rp) => {
let int_fp_jwt = {
let mut claims = JWTClaims::new(Some(self.config.jwt_exp_in_hours));
claims.uuid = Some(utils::new_uuid());
utils::jwt::create_jwt_token(claims, &self.config.jwt_virtual_connection_key)
};
};

let int_fp_session = IntFPSession {
client_id: ctx
.get(&consts::CtxKeys::BACKEND_AUTH_CLIENT_ID)
.unwrap_or(&"".to_string())
.to_string(),
rp_base_url: ctx
.param("backend_url")
.unwrap_or(&"".to_string())
.to_string(),
fp_rp_jwt: res_from_rp.fp_rp_jwt,
};
let int_fp_jwt = {
let mut claims = JWTClaims::new(Some(self.config.jwt_exp_in_hours));
claims.uuid = Some(utils::new_uuid());
utils::jwt::create_jwt_token(claims, &self.config.jwt_virtual_connection_key)
};

let mut jwts = self.jwts_storage.lock().unwrap();
jwts.insert(int_fp_jwt.clone(), int_fp_session);
let int_fp_session = IntFPSession {
client_id: ctx
.get(consts::CtxKeys::BACKEND_AUTH_CLIENT_ID)
.unwrap_or(&"".to_string())
.clone(),
rp_base_url: ctx.param("backend_url").unwrap_or(&"".to_string()).clone(),
fp_rp_jwt: res_from_rp.fp_rp_jwt,
};

let res_to_int = InitTunnelResponseToINT {
ephemeral_public_key: res_from_rp.public_key,
t_b_hash: res_from_rp.t_b_hash,
int_rp_jwt: res_from_rp.int_rp_jwt,
int_fp_jwt,
ntor_static_public_key,
ntor_server_id,
};
let mut jwts = self.jwts_storage.lock().unwrap();
jwts.insert(int_fp_jwt.clone(), int_fp_session);

APIHandlerResponse {
status: StatusCode::OK,
body: Some(res_to_int.to_bytes()),
}
}
let res_to_int = InitTunnelResponseToINT {
ephemeral_public_key: res_from_rp.public_key,
t_b_hash: res_from_rp.t_b_hash,
int_rp_jwt: res_from_rp.int_rp_jwt,
int_fp_jwt,
ntor_static_public_key,
ntor_server_id,
};

APIHandlerResponse {
status: StatusCode::OK,
body: Some(res_to_int.to_bytes()),
}
}

pub fn handle_healthcheck(&self, ctx: &mut Layer8Context) -> APIHandlerResponse {
Expand Down
13 changes: 7 additions & 6 deletions forward-proxy/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use pingora::prelude::*;
use tokio::runtime::Runtime;
use tracing::{debug, info};

mod config;
mod handler;
mod proxy;
mod statistics;

use crate::config::FPConfig;
use crate::handler::ForwardHandler;
use crate::statistics::Statistics;
use pingora::prelude::*;
use config::FPConfig;
use handler::ForwardHandler;
use proxy::ForwardProxy;
use tokio::runtime::Runtime;
use tracing::{debug, info};
use statistics::Statistics;

fn load_config() -> FPConfig {
// Load environment variables from .env file
Expand Down
21 changes: 11 additions & 10 deletions forward-proxy/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::config::TlsConfig;
use crate::handler::ForwardHandler;
use crate::handler::consts::{CtxKeys, HeaderKeys, LogTypes, RequestPaths};
use crate::handler::types::response::ErrorResponse;
use crate::statistics::Statistics;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use boring::x509::X509;
use bytes::Bytes;
Expand All @@ -17,10 +15,14 @@ use pingora_error::ErrorType;
use pingora_router::ctx::{Layer8Context, Layer8ContextTrait};
use pingora_router::handler::ResponseBodyTrait;
use reqwest::header::TRANSFER_ENCODING;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, error, info};

use crate::config::TlsConfig;
use crate::handler::ForwardHandler;
use crate::handler::consts::{CtxKeys, HeaderKeys, LogTypes, RequestPaths};
use crate::handler::types::response::ErrorResponse;
use crate::statistics::Statistics;

pub struct ForwardProxy {
tls_config: TlsConfig,
handler: ForwardHandler,
Expand Down Expand Up @@ -361,7 +363,7 @@ impl ProxyHttp for ForwardProxy {
request_summary = session.request_summary(),
"Forward proxy passing through request body unchanged."
);
*body = Some(Bytes::copy_from_slice(ctx.get_request_body().as_slice()));
*body = Some(Bytes::copy_from_slice(ctx.get_request_body()));
return Ok(());
}
};
Expand Down Expand Up @@ -441,7 +443,6 @@ impl ProxyHttp for ForwardProxy {
"{} token is empty",
HeaderKeys::INT_FP_JWT
);

return Err(pingora::Error::new(pingora::ErrorType::HTTPStatus(
u16::from(StatusCode::BAD_REQUEST),
)));
Expand Down Expand Up @@ -545,7 +546,7 @@ impl ProxyHttp for ForwardProxy {
request_summary = session.request_summary(),
"Forward proxy passing through response body unchanged."
);
*body = Some(Bytes::copy_from_slice(ctx.get_response_body().as_slice()));
*body = Some(Bytes::copy_from_slice(ctx.get_response_body()));
return Ok(None);
}
};
Expand Down
20 changes: 11 additions & 9 deletions pingora-router/src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::utils::get_request_body;
use pingora::http::{Method, RequestHeader, StatusCode};
use pingora::proxy::Session;
use std::collections::HashMap;
use std::time::Instant;

use pingora::http::{Method, RequestHeader, StatusCode};
use pingora::proxy::Session;
use uuid;

use crate::utils::get_request_body;

/*
* Each type in this crate serves a specific purpose and may be updated as requirements evolve.
*/
Expand Down Expand Up @@ -186,8 +188,8 @@ impl Layer8ContextTrait for Layer8Context {
self.request.body.extend(body)
}

fn get_request_body(&self) -> Vec<u8> {
self.request.body.clone()
fn get_request_body(&self) -> &[u8] {
&self.request.body
}

fn set_response_body(&mut self, body: Vec<u8>) {
Expand All @@ -198,8 +200,8 @@ impl Layer8ContextTrait for Layer8Context {
self.response.body.extend(body);
}

fn get_response_body(&self) -> Vec<u8> {
self.response.body.clone()
fn get_response_body(&self) -> &[u8] {
&self.response.body
}

fn get(&self, key: &str) -> Option<&String> {
Expand Down Expand Up @@ -254,10 +256,10 @@ pub trait Layer8ContextTrait {
fn get_response_header(&self) -> &Layer8Header;
fn set_request_body(&mut self, body: Vec<u8>);
fn extend_request_body(&mut self, body: Vec<u8>);
fn get_request_body(&self) -> Vec<u8>;
fn get_request_body(&self) -> &[u8];
fn set_response_body(&mut self, body: Vec<u8>);
fn extend_response_body(&mut self, body: Vec<u8>);
fn get_response_body(&self) -> Vec<u8>;
fn get_response_body(&self) -> &[u8];
fn get(&self, key: &str) -> Option<&String>;
fn set(&mut self, key: String, value: String);
fn set_request_summary(&mut self, summary: Layer8ContextRequestSummary);
Expand Down
12 changes: 6 additions & 6 deletions pingora-router/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub trait ResponseBodyTrait: Serialize + for<'de> Deserialize<'de> + Debug {
serde_json::to_vec(self).unwrap()
}

fn from_bytes(bytes: Vec<u8>) -> Result<Box<Self>, serde_json::Error> {
serde_json::from_slice(&bytes)
fn from_bytes(bytes: &[u8]) -> Result<Box<Self>, serde_json::Error> {
serde_json::from_slice(bytes)
}

/// Override this method to handle error serialization if your handler implements
Expand All @@ -93,8 +93,8 @@ pub trait RequestBodyTrait: Serialize + for<'de> Deserialize<'de> + Debug {
serde_json::to_vec(self).unwrap()
}

fn from_bytes(bytes: Vec<u8>) -> Result<Box<Self>, serde_json::Error> {
serde_json::from_slice(&bytes)
fn from_bytes(bytes: &[u8]) -> Result<Box<Self>, serde_json::Error> {
serde_json::from_slice(bytes)
}
}

Expand All @@ -112,9 +112,9 @@ pub trait RequestBodyTrait: Serialize + for<'de> Deserialize<'de> + Debug {
/// ResponseBodyTrait` (constructed from the JSON error), and a 400 Bad Request status.
pub trait DefaultHandlerTrait {
fn parse_request_body<T: RequestBodyTrait, E: ResponseBodyTrait>(
data: &Vec<u8>,
data: &[u8],
) -> Result<T, Option<E>> {
match T::from_bytes(data.clone()) {
match T::from_bytes(data) {
Ok(body) => Ok(*body),
Err(e) => Err(E::from_json_err(e)),
}
Expand Down
Loading
Loading