Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ repository = "https://github.com/maxcountryman/axum-sessions"
documentation = "https://docs.rs/axum-sessions"

[dependencies]
async-session = "3.0.0"
async-session = { git = "https://github.com/http-rs/async-session", branch = "overhaul-session-and-session-store", default-features = false }
base64 = "0.21.0"
futures = "0.3.21"
hmac = { version = "0.12.1", features = ["std"] }
http-body = "0.4.5"
sha2 = "0.10.6"
tower = "0.4.12"
tracing = "0.1"

Expand All @@ -34,6 +37,7 @@ features = ["sync"]
http = "0.2.8"
hyper = "0.14.19"
serde = "1.0.147"
serde_json = "1.0.93"

[dev-dependencies.rand]
version = "0.8.5"
Expand All @@ -43,3 +47,8 @@ features = ["min_const_gen"]
version = "1.20.1"
default-features = false
features = ["macros", "rt-multi-thread"]

[dev-dependencies.async-session]
git = "https://github.com/http-rs/async-session"
branch = "overhaul-session-and-session-store"
features = ["memory-store"]
35 changes: 18 additions & 17 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@ use std::{
time::Duration,
};

use async_session::{
base64,
hmac::{Hmac, Mac, NewMac},
sha2::Sha256,
SessionStore,
};
use async_session::SessionStore;
use axum::{
http::{
header::{HeaderValue, COOKIE, SET_COOKIE},
Expand All @@ -21,7 +16,10 @@ use axum::{
response::Response,
};
use axum_extra::extract::cookie::{Cookie, Key, SameSite};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use futures::future::BoxFuture;
use hmac::{Hmac, Mac};
use sha2::{digest::generic_array::GenericArray, Sha256};
use tokio::sync::RwLock;
use tower::{Layer, Service};

Expand Down Expand Up @@ -65,7 +63,10 @@ pub struct SessionLayer<Store> {
key: Key,
}

impl<Store: SessionStore> SessionLayer<Store> {
impl<Store> SessionLayer<Store>
where
Store: SessionStore + Clone + Send + Sync + 'static,
{
/// Creates a layer which will attach a [`SessionHandle`] to requests via an
/// extension. This session is derived from a cryptographically signed
/// cookie. When the client sends a valid, known cookie then the session is
Expand Down Expand Up @@ -234,7 +235,7 @@ impl<Store: SessionStore> SessionLayer<Store> {
mac.update(cookie.value().as_bytes());

// Cookie's new value is [MAC | original-value].
let mut new_value = base64::encode(mac.finalize().into_bytes());
let mut new_value = BASE64.encode(mac.finalize().into_bytes());
new_value.push_str(cookie.value());
cookie.set_value(new_value);
}
Expand All @@ -251,18 +252,21 @@ impl<Store: SessionStore> SessionLayer<Store> {

// Split [MAC | original-value] into its two parts.
let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?;
let digest = BASE64.decode(digest_str).map_err(|_| "bad base64 digest")?;

// Perform the verification.
let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
mac.update(value.as_bytes());
mac.verify(&digest)
mac.verify(GenericArray::from_slice(&digest))
.map(|_| value.to_string())
.map_err(|_| "value did not verify")
}
}

impl<Inner, Store: SessionStore> Layer<Inner> for SessionLayer<Store> {
impl<Inner, Store> Layer<Inner> for SessionLayer<Store>
where
Store: SessionStore + Clone + Send + Sync + 'static,
{
type Service = Session<Inner, Store>;

fn layer(&self, inner: Inner) -> Self::Service {
Expand All @@ -280,13 +284,13 @@ pub struct Session<Inner, Store: SessionStore> {
layer: SessionLayer<Store>,
}

impl<Inner, ReqBody, ResBody, Store: SessionStore> Service<Request<ReqBody>>
for Session<Inner, Store>
impl<Inner, ReqBody, ResBody, Store> Service<Request<ReqBody>> for Session<Inner, Store>
where
Inner: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: Send + 'static,
ReqBody: Send + 'static,
Inner::Future: Send + 'static,
Store: SessionStore + Clone + Send + Sync + 'static,
{
type Response = Inner::Response;
type Error = Inner::Error;
Expand Down Expand Up @@ -382,17 +386,14 @@ where

#[cfg(test)]
mod tests {
use async_session::{
serde::{Deserialize, Serialize},
serde_json,
};
use axum::http::{Request, Response};
use http::{
header::{COOKIE, SET_COOKIE},
HeaderValue, StatusCode,
};
use hyper::Body;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

use super::PersistencePolicy;
Expand Down