diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..90ec19a 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -82,6 +82,10 @@ impl Connection { Ok(Self::new_with_middware(client.build())) } + pub fn account_url_base(account_identifier: &str) -> String { + format!("https://{}.snowflakecomputing.com/", &account_identifier) + } + /// Allow a user to provide their own middleware /// /// Users can provide their own middleware to the connection like this: @@ -119,7 +123,7 @@ impl Connection { pub async fn request( &self, query_type: QueryType, - account_identifier: &str, + url_pattern: &str, extra_get_params: &[(&str, &str)], auth: Option<&str>, body: impl serde::Serialize, @@ -144,10 +148,7 @@ impl Connection { ]; get_params.extend_from_slice(extra_get_params); - let url = format!( - "https://{}.snowflakecomputing.com/{}", - &account_identifier, context.path - ); + let url = format!("{}{}", url_pattern, context.path); let url = Url::parse_with_params(&url, get_params)?; let mut headers = HeaderMap::new(); diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 2f4789e..c9f6fc6 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -256,7 +256,7 @@ impl SnowflakeApiBuilder { let session = match self.auth.auth_type { AuthType::Password(args) => Session::password_auth( Arc::clone(&connection), - &self.auth.account_identifier, + self.auth.account_identifier.to_uppercase().as_str(), self.auth.warehouse.as_deref(), self.auth.database.as_deref(), self.auth.schema.as_deref(), @@ -266,7 +266,7 @@ impl SnowflakeApiBuilder { ), AuthType::Certificate(args) => Session::cert_auth( Arc::clone(&connection), - &self.auth.account_identifier, + self.auth.account_identifier.to_uppercase().as_str(), self.auth.warehouse.as_deref(), self.auth.database.as_deref(), self.auth.schema.as_deref(), @@ -276,13 +276,7 @@ impl SnowflakeApiBuilder { ), }; - let account_identifier = self.auth.account_identifier.to_uppercase(); - - Ok(SnowflakeApi::new( - Arc::clone(&connection), - session, - account_identifier, - )) + Ok(SnowflakeApi::new(Arc::clone(&connection), session)) } } @@ -290,16 +284,14 @@ impl SnowflakeApiBuilder { pub struct SnowflakeApi { connection: Arc, session: Session, - account_identifier: String, } impl SnowflakeApi { /// Create a new `SnowflakeApi` object with an existing connection and session. - pub fn new(connection: Arc, session: Session, account_identifier: String) -> Self { + pub fn new(connection: Arc, session: Session) -> Self { Self { connection, session, - account_identifier, } } /// Initialize object with password auth. Authentication happens on the first request. @@ -316,7 +308,7 @@ impl SnowflakeApi { let session = Session::password_auth( Arc::clone(&connection), - account_identifier, + account_identifier.to_uppercase().as_str(), warehouse, database, schema, @@ -325,12 +317,7 @@ impl SnowflakeApi { password, ); - let account_identifier = account_identifier.to_uppercase(); - Ok(Self::new( - Arc::clone(&connection), - session, - account_identifier, - )) + Ok(Self::new(Arc::clone(&connection), session)) } /// Initialize object with private certificate auth. Authentication happens on the first request. @@ -347,7 +334,7 @@ impl SnowflakeApi { let session = Session::cert_auth( Arc::clone(&connection), - account_identifier, + account_identifier.to_uppercase().as_str(), warehouse, database, schema, @@ -356,12 +343,48 @@ impl SnowflakeApi { private_key_pem, ); - let account_identifier = account_identifier.to_uppercase(); - Ok(Self::new( + Ok(Self::new(Arc::clone(&connection), session)) + } + + /// Initialize object with directly provided oauth token. Authentication happens on the first request. + pub fn with_oauth_auth( + account_identifier: &str, + warehouse: Option<&str>, + database: Option<&str>, + schema: Option<&str>, + username: &str, + role: Option<&str>, + token: &str, + ) -> Result { + let connection = Arc::new(Connection::new()?); + + let session = Session::oauth_auth( Arc::clone(&connection), - session, - account_identifier, - )) + account_identifier.to_uppercase().as_str(), + warehouse, + database, + schema, + username, + role, + token, + ); + + Ok(Self::new(Arc::clone(&connection), session)) + } + + /// Initialize object with spcs file provided oauth. Authentication happens on the first request. + pub fn with_spcs_oauth_auth( + warehouse: Option<&str>, + database: Option<&str>, + schema: Option<&str>, + role: Option<&str>, + ) -> Result { + let connection = Arc::new(Connection::new()?); + + let session = + Session::spcs_oauth_auth(Arc::clone(&connection), warehouse, database, schema, role)?; + + Ok(Self::new(Arc::clone(&connection), session)) } pub fn from_env() -> Result { @@ -501,7 +524,7 @@ impl SnowflakeApi { .connection .request::( query_type, - &self.account_identifier, + &self.session.base_url, &[], Some(&parts.session_token_auth_header), body, diff --git a/snowflake-api/src/requests.rs b/snowflake-api/src/requests.rs index 77b0434..612d9b9 100644 --- a/snowflake-api/src/requests.rs +++ b/snowflake-api/src/requests.rs @@ -15,6 +15,7 @@ pub struct LoginRequest { } pub type PasswordLoginRequest = LoginRequest; +pub type OAuthLoginRequest = LoginRequest; #[cfg(feature = "cert-auth")] pub type CertLoginRequest = LoginRequest; @@ -24,6 +25,7 @@ pub struct LoginRequestCommon { pub client_app_id: String, pub client_app_version: String, pub svn_revision: String, + pub base_url: String, pub account_name: String, pub login_name: String, pub session_parameters: SessionParameters, @@ -53,6 +55,15 @@ pub struct PasswordRequestData { pub password: String, } +#[derive(Serialize, Debug)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub struct OAuthRequestData { + #[serde(flatten)] + pub login_request_common: LoginRequestCommon, + pub token: String, + pub authenticator: String, +} + #[derive(Serialize, Debug)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub struct CertRequestData { diff --git a/snowflake-api/src/session.rs b/snowflake-api/src/session.rs index 90acaaf..ddc6d7e 100644 --- a/snowflake-api/src/session.rs +++ b/snowflake-api/src/session.rs @@ -1,3 +1,5 @@ +use std::env; +use std::fs; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -11,8 +13,8 @@ use crate::connection::{Connection, QueryType}; #[cfg(feature = "cert-auth")] use crate::requests::{CertLoginRequest, CertRequestData}; use crate::requests::{ - ClientEnvironment, LoginRequest, LoginRequestCommon, PasswordLoginRequest, PasswordRequestData, - RenewSessionRequest, SessionParameters, + ClientEnvironment, LoginRequest, LoginRequestCommon, OAuthLoginRequest, OAuthRequestData, + PasswordLoginRequest, PasswordRequestData, RenewSessionRequest, SessionParameters, }; use crate::responses::AuthResponse; @@ -31,6 +33,21 @@ pub enum AuthError { #[error("Password auth was requested, but password wasn't provided")] MissingPassword, + #[error("OAuth token auth was requested, but token wasn't provided")] + MissingOAuthToken, + + #[error("SPCS OAuth was requested, but token couldn't be found")] + MissingSPCSOAuthToken, + + #[error("SPCS OAuth was requested, but the account name couldn't be found")] + MissingSPCSAccountName, + + #[error("SPCS OAuth was requested, but the host name couldn't be found")] + MissingSPCSHost, + + #[error("Account identifier is missing")] + MissingAccountIdentifier, + #[error("Certificate auth was requested, but certificate wasn't provided")] MissingCertificate, @@ -105,6 +122,8 @@ impl AuthToken { enum AuthType { Certificate, Password, + OAuth, + SPCSOAuth, } /// Requests, caches, and renews authentication tokens. @@ -117,6 +136,7 @@ pub struct Session { auth_tokens: Mutex>, auth_type: AuthType, + pub base_url: String, account_identifier: String, warehouse: Option, @@ -129,6 +149,7 @@ pub struct Session { #[allow(dead_code)] private_key_pem: Option, password: Option, + oauth_token: Option, } // todo: make builder @@ -155,6 +176,7 @@ impl Session { let username = username.to_uppercase(); let role = role.map(str::to_uppercase); let private_key_pem = Some(private_key_pem.to_string()); + let base_url = Connection::account_url_base(&account_identifier); Self { connection, @@ -162,12 +184,14 @@ impl Session { auth_type: AuthType::Certificate, private_key_pem, account_identifier, + base_url, warehouse: warehouse.map(str::to_uppercase), database, username, role, schema, password: None, + oauth_token: None, } } @@ -192,22 +216,103 @@ impl Session { let username = username.to_uppercase(); let password = Some(password.to_string()); let role = role.map(str::to_uppercase); + let base_url = Connection::account_url_base(&account_identifier); Self { connection, auth_tokens: Mutex::new(None), auth_type: AuthType::Password, account_identifier, + base_url, warehouse: warehouse.map(str::to_uppercase), database, username, role, password, schema, + oauth_token: None, + private_key_pem: None, + } + } + + /// Authenticate using OAuth token and account identifier + // fixme: add builder or introduce structs + #[allow(clippy::too_many_arguments)] + pub fn oauth_auth( + connection: Arc, + account_identifier: &str, + warehouse: Option<&str>, + database: Option<&str>, + schema: Option<&str>, + username: &str, + role: Option<&str>, + oauth_token: &str, + ) -> Self { + let account_identifier = account_identifier.to_uppercase(); + + let database = database.map(str::to_uppercase); + let schema = schema.map(str::to_uppercase); + + let username = username.to_uppercase(); + let oauth_token = Some(oauth_token.to_string()); + let role = role.map(str::to_uppercase); + let base_url = Connection::account_url_base(&account_identifier); + + Self { + connection, + auth_tokens: Mutex::new(None), + auth_type: AuthType::OAuth, + account_identifier, + base_url, + warehouse: warehouse.map(str::to_uppercase), + database, + username, + role, + oauth_token, + schema, + password: None, private_key_pem: None, } } + /// Authenticate using OAuth token and spcs url + // fixme: Get the token + // fixme: + #[allow(clippy::too_many_arguments)] + pub fn spcs_oauth_auth( + connection: Arc, + warehouse: Option<&str>, + database: Option<&str>, + schema: Option<&str>, + role: Option<&str>, + ) -> Result { + let database = database.map(str::to_uppercase); + let schema = schema.map(str::to_uppercase); + + let role = role.map(str::to_uppercase); + + let account_identifier = + env::var("SNOWFLAKE_ACCOUNT").map_err(|_| AuthError::MissingSPCSAccountName)?; + let account_host = env::var("SNOWFLAKE_HOST").map_err(|_| AuthError::MissingSPCSHost)?; + let base_url = format!("https://{}/", account_host); + + Ok(Self { + connection, + auth_tokens: Mutex::new(None), + auth_type: AuthType::SPCSOAuth, + account_identifier, + base_url, + warehouse: warehouse.map(str::to_uppercase), + database, + username: "".to_string(), + role, + oauth_token: None, + schema, + password: None, + private_key_pem: None, + }) + } + /// Get cached token or request a new one if old one has expired. pub async fn get_token(&self) -> Result { let mut auth_tokens = self.auth_tokens.lock().await; @@ -221,14 +326,43 @@ impl Session { AuthType::Certificate => { log::info!("Starting session with certificate authentication"); if cfg!(feature = "cert-auth") { - self.create(self.cert_request_body()?).await + self.create( + &self.warehouse, + &self.database, + &self.schema, + &self.role, + self.cert_request_body()?, + ) + .await } else { Err(AuthError::MissingCertificate)? } } AuthType::Password => { log::info!("Starting session with password authentication"); - self.create(self.passwd_request_body()?).await + self.create( + &self.warehouse, + &self.database, + &self.schema, + &self.role, + self.passwd_request_body()?, + ) + .await + } + AuthType::OAuth => { + log::info!("Starting session with oauth authentication"); + self.create( + &self.warehouse, + &self.database, + &self.schema, + &self.role, + self.oauth_request_body()?, + ) + .await + } + AuthType::SPCSOAuth => { + log::info!("Starting session with spcs oauth authentication"); + self.spcs_create().await } }?; *auth_tokens = Some(tokens); @@ -256,7 +390,7 @@ impl Session { .connection .request::( QueryType::CloseSession, - &self.account_identifier, + &self.base_url, &[("delete", "true")], Some(&tokens.session_token.auth_header()), serde_json::Value::default(), @@ -278,7 +412,8 @@ impl Session { #[cfg(feature = "cert-auth")] fn cert_request_body(&self) -> Result { - let full_identifier = format!("{}.{}", &self.account_identifier, &self.username); + let username = self.username.clone(); + let full_identifier = format!("{}.{}", self.account_identifier, username); let private_key_pem = self .private_key_pem .as_ref() @@ -305,26 +440,77 @@ impl Session { }) } + fn oauth_request_body(&self) -> Result { + let oauth_token = self + .oauth_token + .as_ref() + .ok_or(AuthError::MissingOAuthToken)?; + + Ok(OAuthLoginRequest { + data: OAuthRequestData { + login_request_common: self.login_request_common(), + token: oauth_token.clone(), + authenticator: "OAUTH".to_string(), + }, + }) + } + + async fn spcs_create(&self) -> Result { + // We should wait until we are ready to send the request + // before we read the token, as it expires after 10 minutes + let token = fs::read_to_string("/snowflake/session/token") + .map_err(|_| AuthError::MissingSPCSOAuthToken)?; + + // fixme: I am not sure if waiting until now to read all of these + // is necessary or not, but I figured I would allow the user + // to override these on instantiation and only replace the + // user provided ones with the env provides ones if necessary + let warehouse = self + .warehouse + .clone() + .or(env::var("SNOWFLAKE_WAREHOUSE").ok()); + let database = self + .database + .clone() + .or(env::var("SNOWFLAKE_DATABASE").ok()); + let schema = self.schema.clone().or(env::var("SNOWFLAKE_SCHEMA").ok()); + let role = self.role.clone().or(env::var("SNOWFLAKE_ROLE").ok()); + + let body = OAuthLoginRequest { + data: OAuthRequestData { + login_request_common: self.login_request_common(), + token, + authenticator: "OAUTH".to_string(), + }, + }; + self.create(&warehouse, &database, &schema, &role, body) + .await + } + /// Start new session, all the Snowflake temporary objects will be scoped towards it, /// as well as temporary configuration parameters - async fn create( + async fn create( &self, + warehouse: &Option, + database: &Option, + schema: &Option, + role: &Option, body: LoginRequest, ) -> Result { let mut get_params = Vec::new(); - if let Some(warehouse) = &self.warehouse { + if let Some(warehouse) = &warehouse { get_params.push(("warehouse", warehouse.as_str())); } - if let Some(database) = &self.database { + if let Some(database) = &database { get_params.push(("databaseName", database.as_str())); } - if let Some(schema) = &self.schema { + if let Some(schema) = &schema { get_params.push(("schemaName", schema.as_str())); } - if let Some(role) = &self.role { + if let Some(role) = &role { get_params.push(("roleName", role.as_str())); } @@ -332,13 +518,12 @@ impl Session { .connection .request::( QueryType::LoginRequest, - &self.account_identifier, + &self.base_url, &get_params, None, body, ) .await?; - log::debug!("Auth response: {:?}", resp); match resp { AuthResponse::Login(lr) => { @@ -367,6 +552,7 @@ impl Session { svn_revision: String::new(), account_name: self.account_identifier.clone(), login_name: self.username.clone(), + base_url: self.base_url.clone(), session_parameters: SessionParameters { client_validate_default_parameters: true, }, @@ -392,7 +578,7 @@ impl Session { .connection .request( QueryType::TokenRequest, - &self.account_identifier, + &self.base_url, &[], Some(&auth), body,