Skip to content

Commit

Permalink
Unifty httpd in blackholes
Browse files Browse the repository at this point in the history
This commit unifies the httpd bits of the blackholes, removing a duplicated
loop in three places. The type logic here started out general and got
gradually more concrete and there's scope to reduce duplication even further
but I don't know that it's a pressing issue.

Signed-off-by: Brian L. Troutwine <[email protected]>
  • Loading branch information
blt committed Dec 27, 2024
1 parent 977e2c1 commit eaf3562
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 236 deletions.
1 change: 1 addition & 0 deletions lading/src/blackhole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use serde::{Deserialize, Serialize};

mod common;
pub mod http;
pub mod splunk_hec;
pub mod sqs;
Expand Down
95 changes: 95 additions & 0 deletions lading/src/blackhole/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::service::Service;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto,
};
use lading_signal::Watcher;
use std::{net::SocketAddr, sync::Arc};
use tokio::{net::TcpListener, pin, sync::Semaphore, task::JoinSet};
use tracing::{debug, error, info};

#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Wrapper for [`std::io::Error`].
#[error("IO error: {0}")]
Io(std::io::Error),
}

pub(crate) async fn run_httpd<SF, S>(
addr: SocketAddr,
concurrency_limit: usize,
shutdown: Watcher,
make_service: SF,
) -> Result<(), Error>
where
// "service factory"
SF: Send + Sync + 'static + Clone + Fn() -> S,
// The bounds on `S` per
// https://docs.rs/hyper/latest/hyper/service/trait.Service.html and then
// made concrete per
// https://docs.rs/hyper-util/latest/hyper_util/server/conn/auto/struct.Builder.html#method.serve_connection.
S: Service<
hyper::Request<hyper::body::Incoming>,
Response = hyper::Response<BoxBody<Bytes, hyper::Error>>,
Error = hyper::Error,
> + Send
+ 'static,

S::Future: Send + 'static,
{
let listener = TcpListener::bind(addr).await.map_err(Error::Io)?;
let sem = Arc::new(Semaphore::new(concurrency_limit));
let mut join_set = JoinSet::new();

let shutdown_fut = shutdown.recv();
pin!(shutdown_fut);
loop {
tokio::select! {
() = &mut shutdown_fut => {
info!("Shutdown signal received, stopping accept loop.");
break;
}

incoming = listener.accept() => {
let (stream, addr) = match incoming {
Ok(sa) => sa,
Err(e) => {
error!("Error accepting connection: {e}");
continue;
}
};

let sem = Arc::clone(&sem);
let service_factory = make_service.clone();

join_set.spawn(async move {
debug!("Accepted connection from {addr}");
let permit = match sem.acquire_owned().await {
Ok(p) => p,
Err(e) => {
error!("Semaphore closed: {e}");
return;
}
};

let builder = auto::Builder::new(TokioExecutor::new());
let serve_future = builder.serve_connection_with_upgrades(
TokioIo::new(stream),
service_factory(),
);

if let Err(e) = serve_future.await {
error!("Error serving {addr}: {e}");
}
drop(permit);
});
}
}
}

drop(listener);
while join_set.join_next().await.is_some() {}
Ok(())
}
108 changes: 31 additions & 77 deletions lading/src/blackhole/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,14 @@
//! `requests_received`: Total requests received
//!
use std::{net::SocketAddr, sync::Arc, time::Duration};

use bytes::Bytes;
use http::{header::InvalidHeaderValue, status::InvalidStatusCode, HeaderMap};
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{header, service::service_fn, Request, Response, StatusCode};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto,
};
use hyper::{header, Request, Response, StatusCode};
use metrics::counter;
use serde::{Deserialize, Serialize};
use tokio::{pin, sync::Semaphore, task::JoinSet};
use tracing::{debug, error, info};
use std::{net::SocketAddr, time::Duration};
use tracing::{debug, error};

use super::General;

Expand All @@ -42,9 +36,9 @@ pub enum Error {
/// Failed to deserialize the configuration.
#[error("Failed to deserialize the configuration: {0}")]
Serde(#[from] serde_json::Error),
/// Wrapper for [`std::io::Error`].
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// Wrapper for [`crate::blackhole::common::Error`].
#[error(transparent)]
Common(#[from] crate::blackhole::common::Error),
}

/// Body variant supported by this blackhole.
Expand Down Expand Up @@ -240,72 +234,32 @@ impl Http {
/// Function will return an error if the configuration is invalid or if
/// receiving a packet fails.
pub async fn run(self) -> Result<(), Error> {
let listener = tokio::net::TcpListener::bind(self.httpd_addr).await?;
let sem = Arc::new(Semaphore::new(self.concurrency_limit));
let mut join_set = JoinSet::new();

let shutdown = self.shutdown.recv();
pin!(shutdown);
loop {
tokio::select! {
() = &mut shutdown => {
info!("shutdown signal received");
break;
}

incoming = listener.accept() => {
let (stream, addr) = match incoming {
Ok((s,a)) => (s,a),
Err(e) => {
error!("accept error: {e}");
continue;
}
};

let metric_labels = self.metric_labels.clone();
let body_bytes = self.body_bytes.clone();
let headers = self.headers.clone();
let status = self.status;
let response_delay = self.response_delay;
let sem = Arc::clone(&sem);

join_set.spawn(async move {
debug!("Accepted connection from {addr}");
let permit = match sem.acquire_owned().await {
Ok(p) => p,
Err(e) => {
error!("Semaphore closed: {e}");
return;
}
};

let builder = auto::Builder::new(TokioExecutor::new());
let serve_future = builder
.serve_connection(
TokioIo::new(stream),
service_fn(move |req: Request<hyper::body::Incoming>| {
debug!("REQUEST: {:?}", req);
srv(
status,
metric_labels.clone(),
body_bytes.clone(),
req,
headers.clone(),
response_delay,
)
})
);
crate::blackhole::common::run_httpd(
self.httpd_addr,
self.concurrency_limit,
self.shutdown,
move || {
let metric_labels = self.metric_labels.clone();
let body_bytes = self.body_bytes.clone();
let headers = self.headers.clone();
let status = self.status;
let response_delay = self.response_delay;

hyper::service::service_fn(move |req| {
debug!("REQUEST: {:?}", req);
srv(
status,
metric_labels.clone(),
body_bytes.clone(),
req,
headers.clone(),
response_delay,
)
})
},
)
.await?;

if let Err(e) = serve_future.await {
error!("Error serving {addr}: {e}");
}
drop(permit);
});
}
}
}
drop(listener);
while join_set.join_next().await.is_some() {}
Ok(())
}
}
Expand Down
Loading

0 comments on commit eaf3562

Please sign in to comment.