diff --git a/examples/axum-key-value-store/src/main.rs b/examples/axum-key-value-store/src/main.rs index 2a4bd47e..dde8851e 100644 --- a/examples/axum-key-value-store/src/main.rs +++ b/examples/axum-key-value-store/src/main.rs @@ -76,7 +76,7 @@ fn app() -> Router { ) .sensitive_response_headers(sensitive_headers) // Set a timeout - .layer(TimeoutLayer::new(Duration::from_secs(10))) + .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(10))) // Compress responses .compression() // Set a `Content-Type` if there isn't one already. diff --git a/tower-http/src/timeout/mod.rs b/tower-http/src/timeout/mod.rs index facb6a92..e159b23c 100644 --- a/tower-http/src/timeout/mod.rs +++ b/tower-http/src/timeout/mod.rs @@ -1,7 +1,7 @@ //! Middleware that applies a timeout to requests. //! -//! If the request does not complete within the specified timeout it will be aborted and a `408 -//! Request Timeout` response will be sent. +//! If the request does not complete within the specified timeout, it will be aborted and a +//! response with an empty body and a custom status code will be returned. //! //! # Differences from `tower::timeout` //! @@ -9,14 +9,14 @@ //! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely //! what you want as returning errors will terminate the connection without sending a response. //! -//! This middleware won't change the error type and instead return a `408 Request Timeout` -//! response. That means if your service's error type is [`Infallible`] it will still be -//! [`Infallible`] after applying this middleware. +//! This middleware won't change the error type and instead returns a response with an empty body +//! and the specified status code. That means if your service's error type is [`Infallible`], it will +//! still be [`Infallible`] after applying this middleware. //! //! # Example //! //! ``` -//! use http::{Request, Response}; +//! use http::{Request, Response, StatusCode}; //! use http_body_util::Full; //! use bytes::Bytes; //! use std::{convert::Infallible, time::Duration}; @@ -31,8 +31,8 @@ //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let svc = ServiceBuilder::new() -//! // Timeout requests after 30 seconds -//! .layer(TimeoutLayer::new(Duration::from_secs(30))) +//! // Timeout requests after 30 seconds with the specified status code +//! .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(30))) //! .service_fn(handle); //! # Ok(()) //! # } diff --git a/tower-http/src/timeout/service.rs b/tower-http/src/timeout/service.rs index 230fe717..68ea56ef 100644 --- a/tower-http/src/timeout/service.rs +++ b/tower-http/src/timeout/service.rs @@ -17,12 +17,25 @@ use tower_service::Service; #[derive(Debug, Clone, Copy)] pub struct TimeoutLayer { timeout: Duration, + status_code: StatusCode, } impl TimeoutLayer { /// Creates a new [`TimeoutLayer`]. + /// + /// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout. + /// To customize the response status code, use the `with_status_code` method. + #[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")] pub fn new(timeout: Duration) -> Self { - TimeoutLayer { timeout } + Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout) + } + + /// Creates a new [`TimeoutLayer`] with the specified status code for the timeout response. + pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self { + Self { + timeout, + status_code, + } } } @@ -30,26 +43,37 @@ impl Layer for TimeoutLayer { type Service = Timeout; fn layer(&self, inner: S) -> Self::Service { - Timeout::new(inner, self.timeout) + Timeout::with_status_code(inner, self.status_code, self.timeout) } } /// Middleware which apply a timeout to requests. /// -/// If the request does not complete within the specified timeout it will be aborted and a `408 -/// Request Timeout` response will be sent. -/// /// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct Timeout { inner: S, timeout: Duration, + status_code: StatusCode, } impl Timeout { /// Creates a new [`Timeout`]. + /// + /// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout. + /// To customize the response status code, use the `with_status_code` method. + #[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")] pub fn new(inner: S, timeout: Duration) -> Self { - Self { inner, timeout } + Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout) + } + + /// Creates a new [`Timeout`] with the specified status code for the timeout response. + pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self { + Self { + inner, + timeout, + status_code, + } } define_inner_service_accessors!(); @@ -57,8 +81,17 @@ impl Timeout { /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. /// /// [`Layer`]: tower_layer::Layer + #[deprecated( + since = "0.6.7", + note = "Use `Timeout::layer_with_status_code` instead" + )] pub fn layer(timeout: Duration) -> TimeoutLayer { - TimeoutLayer::new(timeout) + TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout) + } + + /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware with the specified status code. + pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer { + TimeoutLayer::with_status_code(status_code, timeout) } } @@ -81,6 +114,7 @@ where ResponseFuture { inner: self.inner.call(req), sleep, + status_code: self.status_code, } } } @@ -92,6 +126,7 @@ pin_project! { inner: F, #[pin] sleep: Sleep, + status_code: StatusCode, } } @@ -107,7 +142,7 @@ where if this.sleep.poll(cx).is_ready() { let mut res = Response::new(B::default()); - *res.status_mut() = StatusCode::REQUEST_TIMEOUT; + *res.status_mut() = *this.status_code; return Poll::Ready(Ok(res)); } @@ -269,3 +304,93 @@ where Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body)))) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::Body; + use http::{Request, Response, StatusCode}; + use std::time::Duration; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn request_completes_within_timeout() { + let mut service = ServiceBuilder::new() + .layer(TimeoutLayer::with_status_code( + StatusCode::GATEWAY_TIMEOUT, + Duration::from_secs(1), + )) + .service_fn(fast_handler); + + let request = Request::get("/").body(Body::empty()).unwrap(); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn timeout_middleware_with_custom_status_code() { + let timeout_service = Timeout::with_status_code( + tower::service_fn(slow_handler), + StatusCode::REQUEST_TIMEOUT, + Duration::from_millis(10), + ); + + let mut service = ServiceBuilder::new().service(timeout_service); + + let request = Request::get("/").body(Body::empty()).unwrap(); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); + } + + #[tokio::test] + async fn timeout_response_has_empty_body() { + let mut service = ServiceBuilder::new() + .layer(TimeoutLayer::with_status_code( + StatusCode::GATEWAY_TIMEOUT, + Duration::from_millis(10), + )) + .service_fn(slow_handler); + + let request = Request::get("/").body(Body::empty()).unwrap(); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT); + + // Verify the body is empty (default) + use http_body_util::BodyExt; + let body = res.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + assert!(bytes.is_empty()); + } + + #[tokio::test] + async fn deprecated_new_method_compatibility() { + #[allow(deprecated)] + let layer = TimeoutLayer::new(Duration::from_millis(10)); + + let mut service = ServiceBuilder::new().layer(layer).service_fn(slow_handler); + + let request = Request::get("/").body(Body::empty()).unwrap(); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + // Should use default 408 status code + assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); + } + + async fn slow_handler(_req: Request) -> Result, BoxError> { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::empty()) + .unwrap()) + } + + async fn fast_handler(_req: Request) -> Result, BoxError> { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::empty()) + .unwrap()) + } +}