diff --git a/Cargo.toml b/Cargo.toml index a07d002..58c6dc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,12 +36,14 @@ darling = "0.20.10" erased-serde = "0.3.28" futures-util = "0.3.28" governor = "0.6" -hyper = { version = "0.14", default-features = false } +http-body-util = "0.1" +hyper = { version = "1", default-features = false } +hyper-util = { version = "0.1", default-features = false } indexmap = "2.0.0" ipnetwork = "0.20" once_cell = "1.5" -tonic = { version = "0.11.0", default-features = false } -opentelemetry-proto = "0.5.0" +tonic = { version = "0.12", default-features = false } +opentelemetry-proto = "0.7" parking_lot = "0.12.1" proc-macro2 = { version = "1", default-features = false } prometheus = { version = "0.13.3", default-features = false } @@ -68,6 +70,7 @@ tokio = "1.41.0" thread_local = "1.1" tikv-jemallocator = "0.5" tikv-jemalloc-ctl = "0.5" +tower-service = "0.3" yaml-merge-keys = "0.5" # needed for minver diff --git a/examples/Cargo.toml b/examples/Cargo.toml index cd02c67..e4963a0 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -9,7 +9,9 @@ publish = false anyhow = { workspace = true } foundations = { workspace = true } futures-util = { workspace = true } +http-body-util = { workspace = true } hyper = { workspace = true } +hyper-util = { workspace = true, features = ["server", "tokio"] } tokio = { workspace = true, features = ["full"]} [[example]] diff --git a/examples/http_server/main.rs b/examples/http_server/main.rs index 1ff715a..25b8599 100644 --- a/examples/http_server/main.rs +++ b/examples/http_server/main.rs @@ -18,9 +18,11 @@ use foundations::settings::collections::Map; use foundations::telemetry::{self, log, tracing, TelemetryConfig, TelemetryContext}; use foundations::BootstrapResult; use futures_util::stream::{FuturesUnordered, StreamExt}; -use hyper::server::conn::Http; +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; use hyper::service::service_fn; -use hyper::{Body, Request, Response}; +use hyper::{Request, Response}; +use hyper_util::rt::{TokioExecutor, TokioIo}; use std::convert::Infallible; use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::sync::Arc; @@ -193,7 +195,10 @@ async fn serve_connection( } }); - if let Err(e) = Http::new().serve_connection(conn, on_request).await { + if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection(TokioIo::new(conn), on_request) + .await + { log::error!("failed to serve HTTP"; "error" => ?e); metrics::http_server::failed_connections_total(&endpoint_name).inc(); } @@ -204,9 +209,9 @@ async fn serve_connection( #[tracing::span_fn("respond to request")] async fn respond( endpoint_name: Arc, - req: Request, + req: Request, routes: Arc>, -) -> Result, Infallible> { +) -> Result>, Infallible> { log::add_fields! { "request_uri" => req.uri().to_string(), "method" => req.method().to_string() diff --git a/foundations/Cargo.toml b/foundations/Cargo.toml index e27f91b..e9dfb49 100644 --- a/foundations/Cargo.toml +++ b/foundations/Cargo.toml @@ -77,13 +77,15 @@ client-telemetry = ["logging", "metrics", "tracing", "dep:futures-util"] # Enables the telemetry server. telemetry-server = [ + "dep:http-body-util", "dep:hyper", + "dep:hyper-util", "dep:socket2", - "dep:percent-encoding" + "dep:percent-encoding", ] # Enables telemetry reporting over gRPC -telemetry-otlp-grpc = ["dep:tonic", "dep:tokio", "dep:hyper"] +telemetry-otlp-grpc = ["dep:tonic", "tonic/prost", "dep:tokio", "dep:hyper"] # Enables experimental tokio runtime metrics tokio-runtime-metrics = [ @@ -177,11 +179,9 @@ clap = { workspace = true, optional = true } erased-serde = { workspace = true, optional = true } futures-util = { workspace = true, optional = true } governor = { workspace = true, optional = true } -hyper = { workspace = true, optional = true, features = [ - "http1", - "runtime", - "server", -] } +http-body-util = { workspace = true, optional = true } +hyper = { workspace = true, optional = true, features = ["http1", "server"] } +hyper-util = { workspace = true, optional = true, features = ["tokio"] } indexmap = { workspace = true, optional = true, features = ["serde"] } once_cell = { workspace = true, optional = true } opentelemetry-proto = { workspace = true, optional = true, features = ["gen-tonic-messages", "trace"] } diff --git a/foundations/src/telemetry/driver.rs b/foundations/src/telemetry/driver.rs index 91a4ca5..728e605 100644 --- a/foundations/src/telemetry/driver.rs +++ b/foundations/src/telemetry/driver.rs @@ -1,16 +1,14 @@ use crate::utils::feature_use; use crate::BootstrapResult; use futures_util::future::BoxFuture; -use futures_util::stream::{FuturesUnordered, Stream}; -use futures_util::FutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::{FutureExt, Stream}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; feature_use!(cfg(feature = "telemetry-server"), { use super::server::TelemetryServerFuture; - use anyhow::anyhow; - use hyper::Server; use std::net::SocketAddr; }); @@ -38,7 +36,7 @@ impl TelemetryDriver { ) -> Self { Self { #[cfg(feature = "telemetry-server")] - server_addr: server_fut.as_ref().map(Server::local_addr), + server_addr: server_fut.as_ref().map(|fut| fut.local_addr()), #[cfg(feature = "telemetry-server")] server_fut, @@ -66,9 +64,11 @@ impl TelemetryDriver { #[cfg(feature = "telemetry-server")] { if let Some(server_fut) = self.server_fut.take() { - self.tele_futures.push( - async move { Ok(server_fut.with_graceful_shutdown(signal).await?) }.boxed(), - ); + self.tele_futures.push(Box::pin(async move { + server_fut.with_graceful_shutdown(signal).await; + + Ok(()) + })); return; } @@ -93,7 +93,7 @@ impl Future for TelemetryDriver { #[cfg(feature = "telemetry-server")] if let Some(server_fut) = &mut self.server_fut { if let Poll::Ready(res) = Pin::new(server_fut).poll(cx) { - ready_res.push(res.map_err(|err| anyhow!(err))); + match res {} } } diff --git a/foundations/src/telemetry/mod.rs b/foundations/src/telemetry/mod.rs index cc47df0..45dbfba 100644 --- a/foundations/src/telemetry/mod.rs +++ b/foundations/src/telemetry/mod.rs @@ -117,7 +117,9 @@ pub use self::testing::TestTelemetryContext; pub use self::memory_profiler::MemoryProfiler; #[cfg(feature = "telemetry-server")] -pub use self::server::{TelemetryRouteHandler, TelemetryRouteHandlerFuture, TelemetryServerRoute}; +pub use self::server::{ + BoxError, TelemetryRouteHandler, TelemetryRouteHandlerFuture, TelemetryServerRoute, +}; pub use self::driver::TelemetryDriver; pub use self::telemetry_context::{ @@ -290,7 +292,10 @@ pub fn init(config: TelemetryConfig) -> BootstrapResult { #[cfg(feature = "telemetry-server")] { - let server_fut = self::server::init(config.settings.clone(), config.custom_server_routes)?; + let server_fut = server::TelemetryServerFuture::new( + config.settings.clone(), + config.custom_server_routes, + )?; Ok(TelemetryDriver::new(server_fut, tele_futures)) } diff --git a/foundations/src/telemetry/server/mod.rs b/foundations/src/telemetry/server/mod.rs new file mode 100644 index 0000000..c940a45 --- /dev/null +++ b/foundations/src/telemetry/server/mod.rs @@ -0,0 +1,203 @@ +#[cfg(feature = "metrics")] +use super::metrics; +use super::settings::TelemetrySettings; +use crate::telemetry::log; +use crate::BootstrapResult; +use anyhow::Context as _; +use futures_util::future::FutureExt; +use futures_util::{pin_mut, ready}; +use hyper_util::rt::TokioIo; +use socket2::{Domain, SockAddr, Socket, Type}; +use std::convert::Infallible; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::net::TcpListener; +use tokio::sync::watch; + +mod router; + +use router::Router; +pub use router::{ + BoxError, TelemetryRouteHandler, TelemetryRouteHandlerFuture, TelemetryServerRoute, +}; + +pub(super) struct TelemetryServerFuture { + listener: TcpListener, + router: Router, +} + +impl TelemetryServerFuture { + pub(super) fn new( + settings: TelemetrySettings, + custom_routes: Vec, + ) -> BootstrapResult> { + if !settings.server.enabled { + return Ok(None); + } + + let settings = Arc::new(settings); + + // Eagerly init the memory profiler so it gets set up before syscalls are sandboxed with seccomp. + #[cfg(all(target_os = "linux", feature = "memory-profiling"))] + if settings.memory_profiler.enabled { + memory_profiling::profiler(Arc::clone(&settings)) + .map_err(|err| anyhow::anyhow!(err))?; + } + + let addr = settings.server.addr; + + #[cfg(feature = "settings")] + let addr = SocketAddr::from(addr); + + let router = Router::new(custom_routes, settings); + + let listener = { + let std_listener = std::net::TcpListener::from( + bind_socket(addr).with_context(|| format!("binding to socket {addr:?}"))?, + ); + + std_listener.set_nonblocking(true)?; + + tokio::net::TcpListener::from_std(std_listener)? + }; + + Ok(Some(TelemetryServerFuture { listener, router })) + } + pub(super) fn local_addr(&self) -> SocketAddr { + self.listener.local_addr().unwrap() + } + + // Adapted from Hyper 0.14 Server stuff and axum::serve::serve. + pub(super) async fn with_graceful_shutdown( + self, + shutdown_signal: impl Future + Send + Sync + 'static, + ) { + let (signal_tx, signal_rx) = watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + tokio::spawn(async move { + shutdown_signal.await; + + drop(signal_rx); + }); + + let (close_tx, close_rx) = watch::channel(()); + let listener = self.listener; + + pin_mut!(listener); + + loop { + let socket = tokio::select! { + conn = listener.accept() => match conn { + Ok((conn, _)) => TokioIo::new(conn), + Err(e) => { + log::warn!("failed to accept connection"; "error" => e); + + continue; + } + }, + _ = signal_tx.closed() => { break }, + }; + + let router = self.router.clone(); + let signal_tx = Arc::clone(&signal_tx); + let close_rx = close_rx.clone(); + + tokio::spawn(async move { + let conn = hyper::server::conn::http1::Builder::new() + .serve_connection(socket, router) + .with_upgrades(); + + let signal_closed = signal_tx.closed().fuse(); + + pin_mut!(conn); + pin_mut!(signal_closed); + + loop { + tokio::select! { + _ = conn.as_mut() => break, + _ = &mut signal_closed => conn.as_mut().graceful_shutdown(), + } + } + + drop(close_rx); + }); + } + + drop(close_rx); + + close_tx.closed().await; + } +} + +impl Future for TelemetryServerFuture { + type Output = Infallible; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = &mut *self; + + loop { + let socket = match ready!(Pin::new(&mut this.listener).poll_accept(cx)) { + Ok((conn, _)) => TokioIo::new(conn), + Err(e) => { + log::warn!("failed to accept connection"; "error" => e); + + continue; + } + }; + + let router = this.router.clone(); + + tokio::spawn( + hyper::server::conn::http1::Builder::new() + // upgrades needed for websockets + .serve_connection(socket, router) + .with_upgrades(), + ); + } + } +} + +fn bind_socket(addr: SocketAddr) -> BootstrapResult { + let socket = Socket::new( + if addr.is_ipv4() { + Domain::IPV4 + } else { + Domain::IPV6 + }, + Type::STREAM, + None, + )?; + + socket.set_reuse_address(true)?; + #[cfg(unix)] + socket.set_reuse_port(true)?; + socket.bind(&SockAddr::from(addr))?; + socket.listen(1024)?; + + Ok(socket) +} + +#[cfg(all(target_os = "linux", feature = "memory-profiling"))] +mod memory_profiling { + use super::*; + use crate::telemetry::MemoryProfiler; + use crate::Result; + + pub(super) fn profiler(settings: Arc) -> Result { + MemoryProfiler::get_or_init_with(&settings.memory_profiler)?.ok_or_else(|| { + "profiling should be enabled via `_RJEM_MALLOC_CONF=prof:true` env var".into() + }) + } + + pub(super) async fn heap_profile(settings: Arc) -> Result { + profiler(settings)?.heap_profile().await + } + + pub(super) async fn heap_stats(settings: Arc) -> Result { + profiler(settings)?.heap_stats() + } +} diff --git a/foundations/src/telemetry/server.rs b/foundations/src/telemetry/server/router.rs similarity index 53% rename from foundations/src/telemetry/server.rs rename to foundations/src/telemetry/server/router.rs index e98380e..a162385 100644 --- a/foundations/src/telemetry/server.rs +++ b/foundations/src/telemetry/server/router.rs @@ -1,31 +1,29 @@ +#[cfg(all(target_os = "linux", feature = "memory-profiling"))] +use super::memory_profiling; #[cfg(feature = "metrics")] use super::metrics; -use super::settings::TelemetrySettings; -use crate::BootstrapResult; -use anyhow::Context as _; +use crate::telemetry::settings::TelemetrySettings; use futures_util::future::{BoxFuture, FutureExt}; -use hyper::server::conn::{AddrIncoming, AddrStream}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Empty, Full}; +use hyper::body::{Bytes, Incoming}; use hyper::service::Service; -use hyper::{header, Body, Method, Request, Response, Server, StatusCode}; +use hyper::{header, Method, Request, Response, StatusCode}; use percent_encoding::percent_decode_str; -use socket2::{Domain, SockAddr, Socket, Type}; use std::collections::HashMap; use std::convert::Infallible; -use std::future::{ready, Future, Ready}; -use std::net::{SocketAddr, TcpListener}; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; - -pub(super) type TelemetryServerFuture = Server; /// Future returned by [`TelemetryServerRoute::handler`]. pub type TelemetryRouteHandlerFuture = - BoxFuture<'static, std::result::Result, Infallible>>; + BoxFuture<'static, std::result::Result>, Infallible>>; + +/// Error type returned by [`TelemetryRouteHandlerFuture`]. +pub type BoxError = Box; /// Telemetry route handler. pub type TelemetryRouteHandler = Box< - dyn Fn(Request, Arc) -> TelemetryRouteHandlerFuture + dyn Fn(Request, Arc) -> TelemetryRouteHandlerFuture + Send + Sync + 'static, @@ -131,13 +129,25 @@ pub(super) struct Router { } impl Router { - async fn handle_request(&self, req: Request) -> Response { + pub(super) fn new( + custom_routes: Vec, + settings: Arc, + ) -> Self { + Self { + routes: Arc::new(RouteMap::new(custom_routes)), + settings, + } + } + + async fn handle_request(&self, req: Request) -> Response> { let res = Response::builder(); let Ok(path) = percent_decode_str(req.uri().path()).decode_utf8() else { return res .status(StatusCode::BAD_REQUEST) - .body("can't percent-decode URI path as valid UTF-8".into()) + .body(BoxBody::new( + Full::from("can't percent-decode URI path as valid UTF-8").map_err(Into::into), + )) .unwrap(); }; @@ -147,7 +157,10 @@ impl Router { .get(req.method()) .and_then(|e| e.get(&path.to_string())) else { - return res.status(StatusCode::NOT_FOUND).body("".into()).unwrap(); + return res + .status(StatusCode::NOT_FOUND) + .body(BoxBody::new(Empty::new().map_err(Into::into))) + .unwrap(); }; match (handler)(req, Arc::clone(&self.settings)).await { @@ -157,126 +170,32 @@ impl Router { } } -impl Service<&AddrStream> for Router { - type Response = Self; +impl Service> for Router { + type Response = Response>; type Error = Infallible; - type Future = Ready>; + type Future = BoxFuture<'static, Result>; - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _conn: &AddrStream) -> Self::Future { - ready(Ok(self.clone())) - } -} - -impl Service> for Router { - type Response = Response; - type Error = Infallible; - type Future = Pin< - Box> + Send + 'static>, - >; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> Self::Future { let router = self.clone(); async move { Ok(router.handle_request(req).await) }.boxed() } } -pub(super) fn init( - settings: TelemetrySettings, - custom_routes: Vec, -) -> BootstrapResult> { - if !settings.server.enabled { - return Ok(None); - } - - let settings = Arc::new(settings); - - // Eagerly init the memory profiler so it gets set up before syscalls are sandboxed with seccomp. - #[cfg(all(target_os = "linux", feature = "memory-profiling"))] - if settings.memory_profiler.enabled { - memory_profiling::profiler(Arc::clone(&settings)).map_err(|err| anyhow::anyhow!(err))?; - } - - let addr = settings.server.addr; - - #[cfg(feature = "settings")] - let addr = SocketAddr::from(addr); - - let router = Router { - routes: Arc::new(RouteMap::new(custom_routes)), - settings, - }; - - let socket = TcpListener::from( - bind_socket(addr).with_context(|| format!("binding to socket {addr:?}"))?, - ); - - let builder = Server::from_tcp(socket)?; - - Ok(Some(builder.serve(router))) -} - -fn bind_socket(addr: SocketAddr) -> BootstrapResult { - let socket = Socket::new( - if addr.is_ipv4() { - Domain::IPV4 - } else { - Domain::IPV6 - }, - Type::STREAM, - None, - )?; - - socket.set_reuse_address(true)?; - #[cfg(unix)] - socket.set_reuse_port(true)?; - socket.bind(&SockAddr::from(addr))?; - socket.listen(1024)?; - - Ok(socket) -} - fn into_response( content_type: &str, - res: crate::Result>, -) -> std::result::Result, Infallible> { + res: crate::Result>>, +) -> std::result::Result>, Infallible> { Ok(match res { Ok(data) => Response::builder() .header(header::CONTENT_TYPE, content_type) - .body(data.into()) + .body(BoxBody::new(data.into().map_err(Into::into))) .unwrap(), Err(err) => Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(err.to_string().into()) + .body(BoxBody::new( + Full::from(err.to_string()).map_err(Into::into), + )) .unwrap(), }) } - -#[cfg(all(target_os = "linux", feature = "memory-profiling"))] -mod memory_profiling { - use super::*; - use crate::telemetry::MemoryProfiler; - use crate::Result; - - pub(super) fn profiler(settings: Arc) -> Result { - MemoryProfiler::get_or_init_with(&settings.memory_profiler)?.ok_or_else(|| { - "profiling should be enabled via `_RJEM_MALLOC_CONF=prof:true` env var".into() - }) - } - - pub(super) async fn heap_profile(settings: Arc) -> Result { - profiler(settings)?.heap_profile().await - } - - pub(super) async fn heap_stats(settings: Arc) -> Result { - profiler(settings)?.heap_stats() - } -} diff --git a/foundations/src/telemetry/telemetry_context.rs b/foundations/src/telemetry/telemetry_context.rs index ba48dfc..9eca7b8 100644 --- a/foundations/src/telemetry/telemetry_context.rs +++ b/foundations/src/telemetry/telemetry_context.rs @@ -1,3 +1,5 @@ +use futures_util::future::BoxFuture; + use super::TelemetryScope; use crate::utils::feature_use; use std::future::Future; @@ -28,7 +30,7 @@ pub struct WithTelemetryContext<'f, T> { // NOTE: we intentionally erase type here as we can get close to the type // length limit, adding telemetry wrappers on top causes compiler to fail in some // cases. - inner: Pin + Send + 'f>>, + inner: BoxFuture<'f, T>, ctx: TelemetryContext, } diff --git a/foundations/tests/telemetry_server.rs b/foundations/tests/telemetry_server.rs index c18c2ec..2448f3a 100644 --- a/foundations/tests/telemetry_server.rs +++ b/foundations/tests/telemetry_server.rs @@ -1,7 +1,10 @@ use foundations::telemetry::settings::{TelemetryServerSettings, TelemetrySettings}; use foundations::telemetry::{TelemetryConfig, TelemetryServerRoute}; use futures_util::FutureExt; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; use hyper::{Method, Response}; +use std::future::IntoFuture; use std::net::{Ipv4Addr, SocketAddr}; #[cfg(target_os = "linux")] @@ -43,11 +46,17 @@ async fn telemetry_server() { path: "/custom-route".into(), methods: vec![Method::GET], handler: Box::new(|_, _| { - async { Ok(Response::builder().body("Hello".into()).unwrap()) }.boxed() + async { + Ok(Response::new(BoxBody::new( + Full::from("Hello").map_err(Into::into), + ))) + } + .boxed() }), }], }) - .unwrap(), + .unwrap() + .into_future(), ); assert_eq!(