Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions examples/dioxus-axum/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# WorkOS (optional)
# WORKOS_API_KEY =
2 changes: 2 additions & 0 deletions examples/dioxus-axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ server = [
"dep:shield-dioxus-axum",
"dep:shield-memory",
"dep:shield-oidc",
"dep:shield-workos",
"dep:tokio",
"dep:tower-sessions",
"dioxus/server",
Expand All @@ -35,6 +36,7 @@ shield-dioxus.workspace = true
shield-dioxus-axum = { workspace = true, optional = true }
shield-memory = { workspace = true, optional = true }
shield-oidc = { workspace = true, features = ["native-tls"], optional = true }
shield-workos = { workspace = true, optional = true }
tokio = { workspace = true, features = ["rt-multi-thread"], optional = true }
tower-sessions = { workspace = true, optional = true }
tracing.workspace = true
53 changes: 36 additions & 17 deletions examples/dioxus-axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@ fn main() {
#[cfg(feature = "server")]
#[tokio::main]
async fn main() {
use std::sync::Arc;
use std::{env, sync::Arc};

use axum::Router;
use dioxus::{
cli_config::fullstack_address_or_localhost,
prelude::{DioxusRouterExt, *},
};
use shield::{Shield, ShieldOptions};
use shield::{ErasedMethod, Method, Shield, ShieldOptions};
use shield_bootstrap::BootstrapDioxusStyle;
use shield_dioxus_axum::{AxumDioxusIntegration, ShieldLayer};
use shield_memory::{MemoryStorage, User};
use shield_oidc::{Keycloak, OidcMethod};
use shield_workos::{WorkosMethod, WorkosOauthProvider, WorkosOptions};
use tokio::net::TcpListener;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer, cookie::time::Duration};
use tracing::{Level, info};
Expand All @@ -45,21 +46,39 @@ async fn main() {
let storage = MemoryStorage::new();
let shield = Shield::new(
storage.clone(),
vec![Arc::new(
OidcMethod::new(storage).with_providers([Keycloak::builder(
"keycloak",
"http://localhost:18080/realms/Shield",
"client1",
)
.client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ")
.redirect_url(format!(
"http://localhost:{}/api/auth/oidc/sign-in-callback/keycloak",
dioxus::cli_config::devserver_raw_addr()
.map(|addr| addr.port())
.unwrap_or_else(|| addr.port())
))
.build()]),
)],
[
Some(Arc::new(
OidcMethod::new(storage).with_providers([Keycloak::builder(
"keycloak",
"http://localhost:18080/realms/Shield",
"client1",
)
.client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ")
.redirect_url(format!(
"http://localhost:{}/api/auth/oidc/sign-in-callback/keycloak",
dioxus::cli_config::devserver_raw_addr()
.map(|addr| addr.port())
.unwrap_or_else(|| addr.port())
))
.build()]),
) as Arc<dyn ErasedMethod>),
env::var("WORKOS_API_KEY").ok().map(|api_key| {
Arc::new(
WorkosMethod::from_api_key(&api_key).with_options(
WorkosOptions::builder()
.oauth_providers(vec![
WorkosOauthProvider::AppleOAuth,
WorkosOauthProvider::GoogleOAuth,
WorkosOauthProvider::MicrosoftOAuth,
])
.build(),
),
) as Arc<dyn ErasedMethod>
}),
]
.into_iter()
.flatten()
.collect(),
ShieldOptions::default(),
);
let shield_layer = ShieldLayer::new(shield.clone());
Expand Down
11 changes: 7 additions & 4 deletions packages/core/shield/src/action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait Action<P: Provider>: ErasedAction + Send + Sync {
Ok(true)
}

fn forms(&self, provider: P) -> Vec<Form>;
async fn forms(&self, provider: P) -> Result<Vec<Form>, ShieldError>;

async fn call(
&self,
Expand All @@ -62,7 +62,10 @@ pub trait ErasedAction: Send + Sync {
session: Session,
) -> Result<bool, ShieldError>;

fn erased_forms(&self, provider: Box<dyn Any + Send + Sync>) -> Vec<Form>;
async fn erased_forms(
&self,
provider: Box<dyn Any + Send + Sync>,
) -> Result<Vec<Form>, ShieldError>;

async fn erased_call(
&self,
Expand All @@ -89,8 +92,8 @@ macro_rules! erased_action {
self.condition(provider.downcast_ref().expect("TODO"), session)
}

fn erased_forms(&self, provider: Box<dyn std::any::Any + Send + Sync>) -> Vec<$crate::Form> {
self.forms(*provider.downcast().expect("TODO"))
async fn erased_forms(&self, provider: Box<dyn std::any::Any + Send + Sync>) -> Result<Vec<$crate::Form>, $crate::ShieldError> {
self.forms(*provider.downcast().expect("TODO")).await
}

async fn erased_call(
Expand Down
6 changes: 3 additions & 3 deletions packages/core/shield/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ impl SignOutAction {
}))
}

pub fn forms<P: Provider>(_provider: P) -> Vec<Form> {
vec![Form {
pub async fn forms<P: Provider>(_provider: P) -> Result<Vec<Form>, ShieldError> {
Ok(vec![Form {
inputs: vec![Input {
name: "submit".to_owned(),
label: None,
r#type: InputType::Submit(InputTypeSubmit {}),
value: Some(Self::name()),
}],
}]
}])
}
}
2 changes: 1 addition & 1 deletion packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<U: User> Shield<U> {
continue;
}

let forms = action.erased_forms(provider);
let forms = action.erased_forms(provider).await?;
for form in forms {
provider_forms.push(ActionProviderForm {
id: provider_id.clone(),
Expand Down
4 changes: 2 additions & 2 deletions packages/methods/shield-credentials/src/actions/sign_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ impl<U: User + 'static, D: DeserializeOwned + 'static> Action<CredentialsProvide
SignInAction::name()
}

fn forms(&self, _provider: CredentialsProvider) -> Vec<Form> {
vec![self.credentials.form()]
async fn forms(&self, _provider: CredentialsProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![self.credentials.form()])
}

async fn call(
Expand Down
4 changes: 2 additions & 2 deletions packages/methods/shield-credentials/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl Action<CredentialsProvider> for CredentialsSignOutAction {
SignOutAction::condition(provider, session)
}

fn forms(&self, provider: CredentialsProvider) -> Vec<Form> {
SignOutAction::forms(provider)
async fn forms(&self, provider: CredentialsProvider) -> Result<Vec<Form>, ShieldError> {
SignOutAction::forms(provider).await
}

async fn call(
Expand Down
15 changes: 11 additions & 4 deletions packages/methods/shield-oauth/src/actions/sign_in.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use async_trait::async_trait;
use oauth2::{CsrfToken, PkceCodeChallenge, Scope, url::form_urlencoded::parse};
use shield::{
Action, ConfigurationError, Form, Request, Response, Session, SessionError, ShieldError,
SignInAction, erased_action,
Action, ConfigurationError, Form, Input, InputType, InputTypeSubmit, Provider, Request,
Response, Session, SessionError, ShieldError, SignInAction, erased_action,
};

use crate::{
Expand All @@ -23,8 +23,15 @@ impl Action<OauthProvider> for OauthSignInAction {
SignInAction::name()
}

fn forms(&self, _provider: OauthProvider) -> Vec<Form> {
vec![Form { inputs: vec![] }]
async fn forms(&self, provider: OauthProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![Form {
inputs: vec![Input {
name: "submit".to_owned(),
label: None,
r#type: InputType::Submit(InputTypeSubmit::default()),
value: Some(format!("Sign in with {}", provider.name())),
}],
}])
}

async fn call(
Expand Down
4 changes: 2 additions & 2 deletions packages/methods/shield-oauth/src/actions/sign_in_callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ impl<U: User + 'static> Action<OauthProvider> for OauthSignInCallbackAction<U> {
SignInCallbackAction::condition(provider, session)
}

fn forms(&self, _provider: OauthProvider) -> Vec<Form> {
vec![Form { inputs: vec![] }]
async fn forms(&self, _provider: OauthProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![])
}

async fn call(
Expand Down
4 changes: 2 additions & 2 deletions packages/methods/shield-oauth/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ impl Action<OauthProvider> for OauthSignOutAction {
SignOutAction::condition(provider, session)
}

fn forms(&self, provider: OauthProvider) -> Vec<Form> {
SignOutAction::forms(provider)
async fn forms(&self, provider: OauthProvider) -> Result<Vec<Form>, ShieldError> {
SignOutAction::forms(provider).await
}

async fn call(
Expand Down
2 changes: 1 addition & 1 deletion packages/methods/shield-oauth/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use bon::Builder;
#[builder(on(String, into), state_mod(vis = "pub(crate)"))]
pub struct OauthOptions {
#[builder(default = "/")]
pub sign_in_redirect: String,
pub(crate) sign_in_redirect: String,
}

impl Default for OauthOptions {
Expand Down
6 changes: 3 additions & 3 deletions packages/methods/shield-oidc/src/actions/sign_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ impl Action<OidcProvider> for OidcSignInAction {
SignInAction::name()
}

fn forms(&self, provider: OidcProvider) -> Vec<Form> {
vec![Form {
async fn forms(&self, provider: OidcProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![Form {
inputs: vec![Input {
name: "submit".to_owned(),
label: None,
r#type: InputType::Submit(InputTypeSubmit::default()),
value: Some(format!("Sign in with {}", provider.name())),
}],
}]
}])
}

async fn call(
Expand Down
4 changes: 2 additions & 2 deletions packages/methods/shield-oidc/src/actions/sign_in_callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ impl<U: User + 'static> Action<OidcProvider> for OidcSignInCallbackAction<U> {
SignInCallbackAction::condition(provider, session)
}

fn forms(&self, _provider: OidcProvider) -> Vec<Form> {
vec![Form { inputs: vec![] }]
async fn forms(&self, _provider: OidcProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![])
}

async fn call(
Expand Down
4 changes: 2 additions & 2 deletions packages/methods/shield-oidc/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ impl Action<OidcProvider> for OidcSignOutAction {
SignOutAction::condition(provider, session)
}

fn forms(&self, provider: OidcProvider) -> Vec<Form> {
SignOutAction::forms(provider)
async fn forms(&self, provider: OidcProvider) -> Result<Vec<Form>, ShieldError> {
SignOutAction::forms(provider).await
}

async fn call(
Expand Down
2 changes: 1 addition & 1 deletion packages/methods/shield-oidc/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use bon::Builder;
#[builder(on(String, into), state_mod(vis = "pub(crate)"))]
pub struct OidcOptions {
#[builder(default = "/")]
pub sign_in_redirect: String,
pub(crate) sign_in_redirect: String,
}

impl Default for OidcOptions {
Expand Down
Loading