Skip to content

Commit

Permalink
Erase type on server state and place in Arc<State> any hashmap
Browse files Browse the repository at this point in the history
  • Loading branch information
nyxtom committed Jul 18, 2022
1 parent 41afeb9 commit ac0abd3
Show file tree
Hide file tree
Showing 26 changed files with 247 additions and 306 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ serde = "1.0.117"
serde_json = "1.0.59"
routefinder = "0.5.0"
regex = "1.5.5"
hashbrown = "0.12.3"

[dev-dependencies]
async-std = { version = "1.6.5", features = ["unstable", "attributes"] }
Expand Down
10 changes: 2 additions & 8 deletions examples/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::{Arc, RwLock};

use juniper::{http::graphiql, http::GraphQLRequest, RootNode};
use lazy_static::lazy_static;
use tide::{http::mime, Body, Redirect, Request, RequestState, Response, Server, StatusCode};
use tide::{http::mime, Body, Redirect, Request, Response, Server, StatusCode};

#[derive(Clone)]
struct User {
Expand Down Expand Up @@ -76,7 +76,7 @@ lazy_static! {

async fn handle_graphql(mut request: Request) -> tide::Result {
let query: GraphQLRequest = request.body_json().await?;
let response = query.execute(&SCHEMA, request.state());
let response = query.execute(&SCHEMA, request.state::<State>());
let status = if response.is_ok() {
StatusCode::Ok
} else {
Expand Down Expand Up @@ -105,9 +105,3 @@ async fn main() -> std::io::Result<()> {
app.listen("0.0.0.0:8080").await?;
Ok(())
}

impl RequestState<State> for Request {
fn state(&self) -> &State {
self.ext::<State>().unwrap()
}
}
10 changes: 2 additions & 8 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use tide::http::mime;
use tide::utils::{After, Before};
use tide::{Middleware, Next, Request, RequestState, Response, Result, StatusCode};
use tide::{Middleware, Next, Request, Response, Result, StatusCode};

#[derive(Debug)]
struct User {
Expand All @@ -24,7 +24,7 @@ impl UserDatabase {
// application state. Because it depends on a specific request state,
// it would likely be closely tied to a specific application
async fn user_loader(mut request: Request, next: Next) -> Result {
if let Some(user) = request.state().find_user().await {
if let Some(user) = request.state::<UserDatabase>().find_user().await {
tide::log::trace!("user loaded", {user: user.name});
request.set_ext(user);
Ok(next.run(request).await)
Expand Down Expand Up @@ -125,9 +125,3 @@ async fn main() -> Result<()> {
app.listen("127.0.0.1:8080").await?;
Ok(())
}

impl RequestState<UserDatabase> for Request {
fn state(&self) -> &UserDatabase {
self.ext::<UserDatabase>().unwrap()
}
}
12 changes: 2 additions & 10 deletions examples/state.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use tide::RequestState;

#[derive(Clone)]
struct State {
value: Arc<AtomicU32>,
Expand All @@ -22,21 +20,15 @@ async fn main() -> tide::Result<()> {
let mut app = tide::with_state(State::new());
app.with(tide::log::LogMiddleware::new());
app.at("/").get(|req: tide::Request| async move {
let state = req.state();
let state = req.state::<State>();
let value = state.value.load(Ordering::Relaxed);
Ok(format!("{}\n", value))
});
app.at("/inc").get(|req: tide::Request| async move {
let state = req.state();
let state = req.state::<State>();
let value = state.value.fetch_add(1, Ordering::Relaxed) + 1;
Ok(format!("{}\n", value))
});
app.listen("127.0.0.1:8080").await?;
Ok(())
}

impl RequestState<State> for tide::Request {
fn state(&self) -> &State {
self.ext::<State>().unwrap()
}
}
12 changes: 3 additions & 9 deletions examples/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;
use async_std::{fs::OpenOptions, io};
use tempfile::TempDir;
use tide::prelude::*;
use tide::{Body, Request, RequestState, Response, StatusCode};
use tide::{Body, Request, Response, StatusCode};

#[derive(Clone)]
struct TempDirState {
Expand All @@ -24,12 +24,6 @@ impl TempDirState {
}
}

impl RequestState<TempDirState> for Request {
fn state(&self) -> &TempDirState {
self.ext::<TempDirState>().unwrap()
}
}

#[async_std::main]
async fn main() -> Result<(), IoError> {
// tide::log::start();
Expand All @@ -44,7 +38,7 @@ async fn main() -> Result<(), IoError> {
app.at(":file")
.put(|req: Request| async move {
let path = req.param("file")?;
let state = req.state();
let state = req.state::<TempDirState>();
let fs_path = state.path().join(path);

let file = OpenOptions::new()
Expand All @@ -64,7 +58,7 @@ async fn main() -> Result<(), IoError> {
})
.get(|req: Request| async move {
let path = req.param("file")?;
let fs_path = req.state().path().join(path);
let fs_path = req.state::<TempDirState>().path().join(path);

if let Ok(body) = Body::from_file(fs_path).await {
Ok(body.into())
Expand Down
2 changes: 1 addition & 1 deletion src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{Middleware, Request, Response};
/// This trait is automatically implemented for `Fn` types, and so is rarely implemented
/// directly by Tide users.
///
/// In practice, endpoints are functions that take a `Request<State>` as an argument and
/// In practice, endpoints are functions that take a `Request` as an argument and
/// return a type `T` that implements `Into<Response>`.
///
/// # Examples
Expand Down
17 changes: 6 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ mod response_builder;
mod route;
mod router;
mod server;
mod state;

pub mod convert;
pub mod listener;
Expand All @@ -97,8 +98,8 @@ pub use request::Request;
pub use response::Response;
pub use response_builder::ResponseBuilder;
pub use route::Route;
pub use server::RequestState;
pub use server::Server;
pub use state::State;

pub use http_types::{self as http, Body, Error, Status, StatusCode};

Expand All @@ -117,7 +118,7 @@ pub use http_types::{self as http, Body, Error, Status, StatusCode};
/// # Ok(()) }) }
/// ```
#[must_use]
pub fn new() -> server::Server<()> {
pub fn new() -> server::Server {
Server::new()
}

Expand All @@ -131,7 +132,7 @@ pub fn new() -> server::Server<()> {
/// # use async_std::task::block_on;
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
/// #
/// use tide::{Request, RequestState};
/// use tide::{Request};
///
/// /// The shared application state.
/// #[derive(Clone)]
Expand All @@ -144,22 +145,16 @@ pub fn new() -> server::Server<()> {
/// name: "Nori".to_string()
/// };
///
/// impl RequestState<State> for Request {
/// fn state(&self) -> &State {
/// self.ext::<State>().unwrap()
/// }
/// }
///
/// // Initialize the application with state.
/// let mut app = tide::with_state(state);
/// app.at("/").get(|req: Request| async move {
/// Ok(format!("Hello, {}!", &req.state().name))
/// Ok(format!("Hello, {}!", &req.state::<State>().name))
/// });
/// app.listen("127.0.0.1:8080").await?;
/// #
/// # Ok(()) }) }
/// ```
pub fn with_state<State>(state: State) -> server::Server<State>
pub fn with_state<State>(state: State) -> server::Server
where
State: Clone + Send + Sync + 'static,
{
Expand Down
21 changes: 9 additions & 12 deletions src/listener/concurrent_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt};
///```
#[derive(Default)]
pub struct ConcurrentListener<State> {
listeners: Vec<Box<dyn Listener<State>>>,
pub struct ConcurrentListener {
listeners: Vec<Box<dyn Listener>>,
}

impl<State: Clone + Send + Sync + 'static> ConcurrentListener<State> {
impl ConcurrentListener {
/// creates a new ConcurrentListener
pub fn new() -> Self {
Self { listeners: vec![] }
Expand All @@ -59,7 +59,7 @@ impl<State: Clone + Send + Sync + 'static> ConcurrentListener<State> {
/// ```
pub fn add<L>(&mut self, listener: L) -> io::Result<()>
where
L: ToListener<State>,
L: ToListener,
{
self.listeners.push(Box::new(listener.to_listener()?));
Ok(())
Expand All @@ -78,19 +78,16 @@ impl<State: Clone + Send + Sync + 'static> ConcurrentListener<State> {
/// # Ok(()) }) }
pub fn with_listener<L>(mut self, listener: L) -> Self
where
L: ToListener<State>,
L: ToListener,
{
self.add(listener).expect("Unable to add listener");
self
}
}

#[async_trait::async_trait]
impl<State> Listener<State> for ConcurrentListener<State>
where
State: Clone + Send + Sync + 'static,
{
async fn bind(&mut self, app: Server<State>) -> io::Result<()> {
impl Listener for ConcurrentListener {
async fn bind(&mut self, app: Server) -> io::Result<()> {
for listener in self.listeners.iter_mut() {
listener.bind(app.clone()).await?;
}
Expand Down Expand Up @@ -118,13 +115,13 @@ where
}
}

impl<State> Debug for ConcurrentListener<State> {
impl Debug for ConcurrentListener {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self.listeners)
}
}

impl<State> Display for ConcurrentListener<State> {
impl Display for ConcurrentListener {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let string = self
.listeners
Expand Down
24 changes: 9 additions & 15 deletions src/listener/failover_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,12 @@ use crate::listener::ListenInfo;
///}
///```
#[derive(Default)]
pub struct FailoverListener<State> {
listeners: Vec<Option<Box<dyn Listener<State>>>>,
pub struct FailoverListener {
listeners: Vec<Option<Box<dyn Listener>>>,
index: Option<usize>,
}

impl<State> FailoverListener<State>
where
State: Clone + Send + Sync + 'static,
{
impl FailoverListener {
/// creates a new FailoverListener
pub fn new() -> Self {
Self {
Expand All @@ -69,7 +66,7 @@ where
/// ```
pub fn add<L>(&mut self, listener: L) -> io::Result<()>
where
L: ToListener<State>,
L: ToListener,
{
self.listeners.push(Some(Box::new(listener.to_listener()?)));
Ok(())
Expand All @@ -88,19 +85,16 @@ where
/// # Ok(()) }) }
pub fn with_listener<L>(mut self, listener: L) -> Self
where
L: ToListener<State>,
L: ToListener,
{
self.add(listener).expect("Unable to add listener");
self
}
}

#[async_trait::async_trait]
impl<State> Listener<State> for FailoverListener<State>
where
State: Clone + Send + Sync + 'static,
{
async fn bind(&mut self, app: Server<State>) -> io::Result<()> {
impl Listener for FailoverListener {
async fn bind(&mut self, app: Server) -> io::Result<()> {
for (index, listener) in self.listeners.iter_mut().enumerate() {
let listener = listener.as_deref_mut().expect("bind called twice");
match listener.bind(app.clone()).await {
Expand Down Expand Up @@ -148,13 +142,13 @@ where
}
}

impl<State> Debug for FailoverListener<State> {
impl Debug for FailoverListener {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self.listeners)
}
}

impl<State> Display for FailoverListener<State> {
impl Display for FailoverListener {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let string = self
.listeners
Expand Down
15 changes: 4 additions & 11 deletions src/listener/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,11 @@ pub(crate) use unix_listener::UnixListener;
/// implement at least one [`ToListener`](crate::listener::ToListener) that
/// outputs your Listener type.
#[async_trait]
pub trait Listener<State>: Debug + Display + Send + Sync + 'static
where
State: Send + Sync + 'static,
{
pub trait Listener: Debug + Display + Send + Sync + 'static {
/// Bind the listener. This starts the listening process by opening the
/// necessary network ports, but not yet accepting incoming connections. This
/// method must be called before `accept`.
async fn bind(&mut self, app: Server<State>) -> io::Result<()>;
async fn bind(&mut self, app: Server) -> io::Result<()>;

/// Start accepting incoming connections. This method must be called only
/// after `bind` has succeeded.
Expand All @@ -54,12 +51,8 @@ where
}

#[async_trait]
impl<L, State> Listener<State> for Box<L>
where
L: Listener<State>,
State: Send + Sync + 'static,
{
async fn bind(&mut self, app: Server<State>) -> io::Result<()> {
impl<L: Listener> Listener for Box<L> {
async fn bind(&mut self, app: Server) -> io::Result<()> {
self.as_mut().bind(app).await
}

Expand Down
Loading

0 comments on commit ac0abd3

Please sign in to comment.