Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ impl Connection {
Ok(Self::new_with_middware(client.build()))
}

pub fn account_url_base(account_identifier: &str) -> String {
format!("https://{}.snowflakecomputing.com/", &account_identifier)
}

/// Allow a user to provide their own middleware
///
/// Users can provide their own middleware to the connection like this:
Expand Down Expand Up @@ -119,7 +123,7 @@ impl Connection {
pub async fn request<R: serde::de::DeserializeOwned>(
&self,
query_type: QueryType,
account_identifier: &str,
url_pattern: &str,
extra_get_params: &[(&str, &str)],
auth: Option<&str>,
body: impl serde::Serialize,
Expand All @@ -144,10 +148,7 @@ impl Connection {
];
get_params.extend_from_slice(extra_get_params);

let url = format!(
"https://{}.snowflakecomputing.com/{}",
&account_identifier, context.path
);
let url = format!("{}{}", url_pattern, context.path);
let url = Url::parse_with_params(&url, get_params)?;

let mut headers = HeaderMap::new();
Expand Down
75 changes: 49 additions & 26 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl SnowflakeApiBuilder {
let session = match self.auth.auth_type {
AuthType::Password(args) => Session::password_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.account_identifier.to_uppercase().as_str(),
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
Expand All @@ -266,7 +266,7 @@ impl SnowflakeApiBuilder {
),
AuthType::Certificate(args) => Session::cert_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.account_identifier.to_uppercase().as_str(),
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
Expand All @@ -276,30 +276,22 @@ impl SnowflakeApiBuilder {
),
};

let account_identifier = self.auth.account_identifier.to_uppercase();

Ok(SnowflakeApi::new(
Arc::clone(&connection),
session,
account_identifier,
))
Ok(SnowflakeApi::new(Arc::clone(&connection), session))
}
}

/// Snowflake API, keeps connection pool and manages session for you
pub struct SnowflakeApi {
connection: Arc<Connection>,
session: Session,
account_identifier: String,
}

impl SnowflakeApi {
/// Create a new `SnowflakeApi` object with an existing connection and session.
pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
pub fn new(connection: Arc<Connection>, session: Session) -> Self {
Self {
connection,
session,
account_identifier,
}
}
/// Initialize object with password auth. Authentication happens on the first request.
Expand All @@ -316,7 +308,7 @@ impl SnowflakeApi {

let session = Session::password_auth(
Arc::clone(&connection),
account_identifier,
account_identifier.to_uppercase().as_str(),
warehouse,
database,
schema,
Expand All @@ -325,12 +317,7 @@ impl SnowflakeApi {
password,
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
))
Ok(Self::new(Arc::clone(&connection), session))
}

/// Initialize object with private certificate auth. Authentication happens on the first request.
Expand All @@ -347,7 +334,7 @@ impl SnowflakeApi {

let session = Session::cert_auth(
Arc::clone(&connection),
account_identifier,
account_identifier.to_uppercase().as_str(),
warehouse,
database,
schema,
Expand All @@ -356,12 +343,48 @@ impl SnowflakeApi {
private_key_pem,
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self::new(
Ok(Self::new(Arc::clone(&connection), session))
}

/// Initialize object with directly provided oauth token. Authentication happens on the first request.
pub fn with_oauth_auth(
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
token: &str,
) -> Result<Self, SnowflakeApiError> {
let connection = Arc::new(Connection::new()?);

let session = Session::oauth_auth(
Arc::clone(&connection),
session,
account_identifier,
))
account_identifier.to_uppercase().as_str(),
warehouse,
database,
schema,
username,
role,
token,
);

Ok(Self::new(Arc::clone(&connection), session))
}

/// Initialize object with spcs file provided oauth. Authentication happens on the first request.
pub fn with_spcs_oauth_auth(
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
role: Option<&str>,
) -> Result<Self, SnowflakeApiError> {
let connection = Arc::new(Connection::new()?);

let session =
Session::spcs_oauth_auth(Arc::clone(&connection), warehouse, database, schema, role)?;

Ok(Self::new(Arc::clone(&connection), session))
}

pub fn from_env() -> Result<Self, SnowflakeApiError> {
Expand Down Expand Up @@ -501,7 +524,7 @@ impl SnowflakeApi {
.connection
.request::<R>(
query_type,
&self.account_identifier,
&self.session.base_url,
&[],
Some(&parts.session_token_auth_header),
body,
Expand Down
11 changes: 11 additions & 0 deletions snowflake-api/src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct LoginRequest<T> {
}

pub type PasswordLoginRequest = LoginRequest<PasswordRequestData>;
pub type OAuthLoginRequest = LoginRequest<OAuthRequestData>;
#[cfg(feature = "cert-auth")]
pub type CertLoginRequest = LoginRequest<CertRequestData>;

Expand All @@ -24,6 +25,7 @@ pub struct LoginRequestCommon {
pub client_app_id: String,
pub client_app_version: String,
pub svn_revision: String,
pub base_url: String,
pub account_name: String,
pub login_name: String,
pub session_parameters: SessionParameters,
Expand Down Expand Up @@ -53,6 +55,15 @@ pub struct PasswordRequestData {
pub password: String,
}

#[derive(Serialize, Debug)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub struct OAuthRequestData {
#[serde(flatten)]
pub login_request_common: LoginRequestCommon,
pub token: String,
pub authenticator: String,
}

#[derive(Serialize, Debug)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub struct CertRequestData {
Expand Down
Loading