Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ pub enum ConnectionError {
/// Container for query parameters
/// This API has different endpoints and MIME types for different requests
struct QueryContext {
path: &'static str,
path: String,
accept_mime: &'static str,
method: reqwest::Method
}

pub enum QueryType {
Expand All @@ -39,30 +40,40 @@ pub enum QueryType {
CloseSession,
JsonQuery,
ArrowQuery,
ArrowQueryResult(String),
}

impl QueryType {
const fn query_context(&self) -> QueryContext {
fn query_context(&self) -> QueryContext {
match self {
Self::LoginRequest => QueryContext {
path: "session/v1/login-request",
path: "session/v1/login-request".to_string(),
accept_mime: "application/json",
method: reqwest::Method::POST,
},
Self::TokenRequest => QueryContext {
path: "/session/token-request",
path: "/session/token-request".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::CloseSession => QueryContext {
path: "session",
path: "session".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::JsonQuery => QueryContext {
path: "queries/v1/query-request",
path: "queries/v1/query-request".to_string(),
accept_mime: "application/json",
method: reqwest::Method::POST,
},
Self::ArrowQuery => QueryContext {
path: "queries/v1/query-request",
path: "queries/v1/query-request".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::ArrowQueryResult(query_result_url) => QueryContext {
path: query_result_url.to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::GET,
},
}
}
Expand Down Expand Up @@ -163,14 +174,22 @@ impl Connection {
}

// todo: persist client to use connection polling
let resp = self
.client
.post(url)
.headers(headers)
.json(&body)
.send()
.await?;

let resp = match context.method {
reqwest::Method::POST => self
.client
.post(url)
.headers(headers)
.json(&body)
.send()
.await?,
reqwest::Method::GET => self
.client
.get(url)
.headers(headers)
.send()
.await?,
_ => panic!("Unsupported method"),
};
Ok(resp.json::<R>().await?)
}

Expand Down
40 changes: 38 additions & 2 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ impl SnowflakeApi {

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => put::put(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
Expand All @@ -430,14 +431,21 @@ impl SnowflakeApi {
}

async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
let resp = self
let mut resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {:?}", resp);

if let ExecResponse::QueryAsync(data) = &resp {
log::debug!("Got async exec response");
resp = self.get_async_exec_result(&data.data.get_result_url).await?;
log::debug!("Got result for async exec: {:?}", resp);
}

let resp = match resp {
// processable response
ExecResponse::Query(qr) => Ok(qr),
ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
Expand Down Expand Up @@ -504,10 +512,38 @@ impl SnowflakeApi {
&self.account_identifier,
&[],
Some(&parts.session_token_auth_header),
body,
Some(body),
)
.await?;

Ok(resp)
}

pub async fn get_async_exec_result(&self, query_result_url: &String) -> Result<ExecResponse, SnowflakeApiError>{
log::debug!("Getting async exec result: {}", query_result_url);

let mut delay = 1; // Initial delay of 1 second

loop {
let parts = self.session.get_token().await?;
let resp = self
.connection
.request::<ExecResponse>(
QueryType::ArrowQueryResult(query_result_url.to_string()),
&self.account_identifier,
&[],
Some(&parts.session_token_auth_header),
serde_json::Value::default()
)
.await?;

if let ExecResponse::QueryAsync(_) = &resp {
// simple exponential retry with a maximum wait time of 5 seconds
tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
delay = (delay * 2).min(5); // cap delay to 5 seconds
} else {
return Ok(resp);
}
};
}
}
11 changes: 10 additions & 1 deletion snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use serde::Deserialize;
#[serde(untagged)]
pub enum ExecResponse {
Query(QueryExecResponse),
QueryAsync(QueryAsyncExecResponse),
PutGet(PutGetExecResponse),
Error(ExecErrorResponse),
}
Expand Down Expand Up @@ -34,6 +35,7 @@ pub struct BaseRestResponse<D> {

pub type PutGetExecResponse = BaseRestResponse<PutGetResponseData>;
pub type QueryExecResponse = BaseRestResponse<QueryExecResponseData>;
pub type QueryAsyncExecResponse = BaseRestResponse<QueryAsyncExecResponseData>;
pub type ExecErrorResponse = BaseRestResponse<ExecErrorResponseData>;
pub type AuthErrorResponse = BaseRestResponse<AuthErrorResponseData>;
pub type AuthenticatorResponse = BaseRestResponse<AuthenticatorResponseData>;
Expand All @@ -54,7 +56,7 @@ pub struct ExecErrorResponseData {
pub pos: Option<i64>,

// fixme: only valid for exec query response error? present in any exec query response?
pub query_id: String,
pub query_id: Option<String>,
pub sql_state: String,
}

Expand Down Expand Up @@ -151,6 +153,13 @@ pub struct QueryExecResponseData {
// `sendResultTime`, `queryResultFormat`, `queryContext` also exist
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct QueryAsyncExecResponseData {
pub query_id: String,
pub get_result_url: String,
}

#[derive(Deserialize, Debug)]
pub struct ExecResponseRowType {
pub name: String,
Expand Down