Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PKCE support for SSO #5189

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: add PKCE
avdb13 committed Nov 28, 2024
commit 52dd992d2408d3bf8578c6a60887e1e7b0748961
6 changes: 6 additions & 0 deletions crates/api_common/src/oauth_provider.rs
Original file line number Diff line number Diff line change
@@ -25,6 +25,8 @@ pub struct CreateOAuthProvider {
#[cfg_attr(feature = "full", ts(optional))]
pub account_linking_enabled: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub use_pkce: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub enabled: Option<bool>,
}

@@ -54,6 +56,8 @@ pub struct EditOAuthProvider {
#[cfg_attr(feature = "full", ts(optional))]
pub account_linking_enabled: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub use_pkce: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub enabled: Option<bool>,
}

@@ -82,4 +86,6 @@ pub struct AuthenticateWithOauth {
/// An answer is mandatory if require application is enabled on the server
#[cfg_attr(feature = "full", ts(optional))]
pub answer: Option<String>,
#[cfg_attr(feature = "full", ts(optional))]
pub pkce_code_verifier: Option<String>,
}
1 change: 1 addition & 0 deletions crates/api_crud/Cargo.toml
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ anyhow.workspace = true
chrono.workspace = true
webmention = "0.6.0"
accept-language = "3.1.0"
regex = { workspace = true }
serde_json = { workspace = true }
serde = { workspace = true }
serde_with = { workspace = true }
1 change: 1 addition & 0 deletions crates/api_crud/src/oauth_provider/create.rs
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ pub async fn create_oauth_provider(
scopes: data.scopes.to_string(),
auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled,
use_pkce: data.use_pkce,
enabled: data.enabled,
};
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
1 change: 1 addition & 0 deletions crates/api_crud/src/oauth_provider/update.rs
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ pub async fn update_oauth_provider(
auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled,
enabled: data.enabled,
use_pkce: data.use_pkce,
updated: Some(Some(Utc::now())),
};

54 changes: 43 additions & 11 deletions crates/api_crud/src/user/create.rs
Original file line number Diff line number Diff line change
@@ -44,9 +44,10 @@ use lemmy_utils::{
validation::is_valid_actor_name,
},
};
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashSet;
use std::{collections::HashSet, sync::LazyLock};

#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
@@ -218,6 +219,11 @@ pub async fn authenticate_with_oauth(
Err(LemmyErrorType::OauthAuthorizationInvalid)?
}

// validate the PKCE challenge
if let Some(code_verifier) = &data.pkce_code_verifier {
check_code_verifier(code_verifier)?;
}

// Fetch the OAUTH provider and make sure it's enabled
let oauth_provider_id = data.oauth_provider_id;
let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id)
@@ -229,9 +235,14 @@ pub async fn authenticate_with_oauth(
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
}

let token_response =
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str())
.await?;
let token_response = oauth_request_access_token(
&context,
&oauth_provider,
&data.code,
data.pkce_code_verifier.as_deref(),
redirect_uri.as_str(),
)
.await?;

let user_info = oidc_get_user_info(
&context,
@@ -512,20 +523,27 @@ async fn oauth_request_access_token(
context: &Data<LemmyContext>,
oauth_provider: &OAuthProvider,
code: &str,
pkce_code_verifier: Option<&str>,
redirect_uri: &str,
) -> LemmyResult<TokenResponse> {
let mut form = vec![
("client_id", &*oauth_provider.client_id),
("client_secret", &*oauth_provider.client_secret),
("code", code),
("grant_type", "authorization_code"),
("redirect_uri", redirect_uri),
];

if let Some(code_verifier) = pkce_code_verifier {
form.push(("code_verifier", code_verifier));
}

// Request an Access Token from the OAUTH provider
let response = context
.client()
.post(oauth_provider.token_endpoint.as_str())
.header("Accept", "application/json")
.form(&[
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", redirect_uri),
("client_id", &oauth_provider.client_id),
("client_secret", &oauth_provider.client_secret),
])
.form(&form[..])
.send()
.await
.with_lemmy_type(LemmyErrorType::OauthLoginFailed)?
@@ -575,3 +593,17 @@ fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult<Strin
}
Err(LemmyErrorType::OauthLoginFailed)?
}

#[allow(clippy::expect_used)]
fn check_code_verifier(code_verifier: &str) -> LemmyResult<()> {
static VALID_CODE_VERIFIER_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex"));

let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier);

if check {
Ok(())
} else {
Err(LemmyErrorType::InvalidCodeVerifier.into())
}
}
1 change: 1 addition & 0 deletions crates/db_schema/src/schema.rs
Original file line number Diff line number Diff line change
@@ -636,6 +636,7 @@ diesel::table! {
enabled -> Bool,
published -> Timestamptz,
updated -> Nullable<Timestamptz>,
use_pkce -> Bool,
}
}

5 changes: 5 additions & 0 deletions crates/db_schema/src/source/oauth_provider.rs
Original file line number Diff line number Diff line change
@@ -62,6 +62,8 @@ pub struct OAuthProvider {
pub published: DateTime<Utc>,
#[cfg_attr(feature = "full", ts(optional))]
pub updated: Option<DateTime<Utc>>,
/// switch to enable or disable PKCE
pub use_pkce: bool,
}

#[derive(Clone, PartialEq, Eq, Debug, Deserialize)]
@@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider {
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
state.serialize_field("client_id", &self.0.client_id)?;
state.serialize_field("scopes", &self.0.scopes)?;
state.serialize_field("use_pkce", &self.0.use_pkce)?;
state.end()
}
}
@@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm {
pub scopes: String,
pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>,
pub use_pkce: Option<bool>,
pub enabled: Option<bool>,
}

@@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm {
pub scopes: Option<String>,
pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>,
pub use_pkce: Option<bool>,
pub enabled: Option<bool>,
pub updated: Option<Option<DateTime<Utc>>>,
}
1 change: 1 addition & 0 deletions crates/utils/src/error.rs
Original file line number Diff line number Diff line change
@@ -76,6 +76,7 @@ pub enum LemmyErrorType {
InvalidEmailAddress(String),
RateLimitError,
InvalidName,
InvalidCodeVerifier,
InvalidDisplayName,
InvalidMatrixId,
InvalidPostTitle,
3 changes: 3 additions & 0 deletions migrations/2024-11-23-234637_oauth_pkce/down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE oauth_provider
DROP COLUMN use_pkce;

3 changes: 3 additions & 0 deletions migrations/2024-11-23-234637_oauth_pkce/up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE oauth_provider
ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL;