241 changes: 194 additions & 47 deletions snowflake-api/src/lib.rs
@@ -23,6 +23,7 @@ use arrow::record_batch::RecordBatch;
use base64::Engine;
use bytes::{Buf, Bytes};
use futures::future::try_join_all;
use futures::stream::{self, Stream, StreamExt, TryStreamExt};
use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;
@@ -182,6 +183,21 @@ impl RawQueryResult {
}
}

/// Internal representation of a parsed query response
/// This unifies the logic for both streaming and non-streaming execution
enum ParsedQueryResponse {
Empty,
Json {
value: serde_json::Value,
schema: Vec<FieldSchema>,
},
Arrow {
base64: String,
chunks: Vec<responses::ExecResponseChunk>,
chunk_headers: std::collections::HashMap<String, String>,
},
}
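// Illustrative only, not part of this diff: consumers branch on the three
// variants, as `exec_arrow_raw` and `exec_stream` below both do. A minimal
// sketch of that dispatch:
//
// match self.execute_query(sql).await? {
//     ParsedQueryResponse::Empty => { /* zero rows; nothing to decode */ }
//     ParsedQueryResponse::Json { value, schema } => { /* inline JSON rowset */ }
//     ParsedQueryResponse::Arrow { base64, chunks, chunk_headers } => {
//         /* base64-encoded first rows, plus zero or more remote chunks */
//     }
// }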

pub struct AuthArgs {
pub account_identifier: String,
pub warehouse: Option<String>,
@@ -376,6 +392,50 @@ impl SnowflakeApi {
Ok(())
}

/// Common method to execute a query and parse the response
/// This unifies the logic used by both exec_raw and exec_stream
async fn execute_query(&self, sql: &str) -> Result<ParsedQueryResponse, SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {resp:?}");

let resp = match resp {
ExecResponse::Query(qr) => Ok(qr),
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}?;

// Handle empty response
if resp.data.returned == 0 {
log::debug!("Got response with 0 rows");
return Ok(ParsedQueryResponse::Empty);
}

// Handle JSON response
if let Some(value) = resp.data.rowset {
log::debug!("Got JSON response");
return Ok(ParsedQueryResponse::Json {
value,
schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
});
}

// Handle Arrow response
if let Some(base64) = resp.data.rowset_base64 {
Ok(ParsedQueryResponse::Arrow {
base64,
chunks: resp.data.chunks,
chunk_headers: resp.data.chunk_headers,
})
} else {
Err(SnowflakeApiError::BrokenResponse)
}
}

/// Execute a single query against the API.
/// If the statement is a PUT, the file will be uploaded to Snowflake-managed storage
pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
@@ -415,6 +475,114 @@ impl SnowflakeApi {
}
}

/// Execute a query and return a stream of Arrow RecordBatches.
/// This method provides streaming access to query results, fetching chunks on-demand
/// rather than loading all data into memory at once.
///
/// # Returns
///
/// - For SELECT queries with Arrow results: A stream of `RecordBatch` items
/// - For non-SELECT queries: Returns an error (use `exec()` or `exec_raw()` instead)
/// - For PUT queries: Returns an error (use `exec()` or `exec_raw()` instead)
///
/// # Example
///
/// ```rust,no_run
/// use futures::stream::StreamExt;
///
/// # async fn example(api: &snowflake_api::SnowflakeApi) -> Result<(), Box<dyn std::error::Error>> {
/// let mut stream = api.exec_stream("SELECT * FROM large_table").await?;
///
/// while let Some(batch_result) = stream.next().await {
/// let batch = batch_result?;
/// // Process each batch as it arrives
/// println!("Received batch with {} rows", batch.num_rows());
/// }
/// # Ok(())
/// # }
/// ```
pub async fn exec_stream(
&self,
sql: &str,
) -> Result<impl Stream<Item = Result<RecordBatch, SnowflakeApiError>>, SnowflakeApiError> {
let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();

// PUT commands are not supported for streaming
if put_re.is_match(sql) {
return Err(SnowflakeApiError::Unimplemented(
"Streaming is not supported for PUT queries".to_string(),
));
}
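        // For illustration: the guard above matches statements that begin with
        // PUT, even behind leading block comments, e.g.
        //   "PUT file:///tmp/a.csv @my_stage"
        //   "/* hint */ put file://..."
        // but not "SELECT 'put'", since the pattern is anchored at the start
        // of the statement.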

match self.execute_query(sql).await? {
ParsedQueryResponse::Empty => Ok(stream::empty().boxed()),
ParsedQueryResponse::Json { .. } => Err(SnowflakeApiError::Unimplemented(
"Streaming is not supported for JSON responses. Use exec() or exec_raw() instead."
.to_string(),
)),
ParsedQueryResponse::Arrow {
base64,
chunks,
chunk_headers,
} => {
let connection = Arc::clone(&self.connection);
let chunk_headers_clone = chunk_headers.clone();

// Create a stream that first yields batches from base64 chunk (if present),
// then streams chunks one by one
let base64_stream: Box<
dyn Stream<Item = Result<RecordBatch, SnowflakeApiError>> + Send + Unpin,
> = if !base64.is_empty() {
log::debug!("Streaming base64 encoded response");
let bytes = Bytes::from(
base64::engine::general_purpose::STANDARD.decode(base64)
.map_err(SnowflakeApiError::from)?,
);
match Self::bytes_to_batches_stream(bytes) {
Ok(s) => Box::new(s.map_err(SnowflakeApiError::from).boxed()),
Err(e) => Box::new(
stream::once(async move { Err(SnowflakeApiError::from(e)) }).boxed(),
),
}
} else {
Box::new(stream::empty().boxed())
};

// Create a stream for remote chunks
let chunks_stream = stream::iter(chunks.into_iter().enumerate())
.then(move |(idx, chunk)| {
let connection = Arc::clone(&connection);
let chunk_headers = chunk_headers_clone.clone();
async move {
log::debug!("Fetching chunk {} of {}", idx + 1, chunk.url);
connection
.get_chunk(&chunk.url, &chunk_headers)
.await
.map_err(SnowflakeApiError::from)
}
})
.and_then(move |bytes| async move {
match Self::bytes_to_batches_stream(bytes) {
Ok(s) => Ok(s.map_err(SnowflakeApiError::from)),
Err(e) => Err(SnowflakeApiError::from(e)),
}
})
.try_flatten();
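                // Note: `then` resolves each fetch future before polling the
                // next item, so chunks are downloaded sequentially and in
                // order; `and_then` converts each fetched `Bytes` into a batch
                // stream, and `try_flatten` concatenates those streams,
                // surfacing any error as a stream item.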

                // Chain the streams so the base64-encoded batches are yielded
                // before the remote chunks, preserving row order
                Ok(base64_stream.chain(chunks_stream.boxed()).boxed())
}
}
}

/// Convert bytes to a stream of RecordBatches
fn bytes_to_batches_stream(
bytes: Bytes,
) -> Result<impl Stream<Item = Result<RecordBatch, ArrowError>>, ArrowError> {
let reader = StreamReader::try_new(bytes.reader(), None)?;
Ok(stream::iter(reader.into_iter()))
}
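    // A minimal usage sketch, assuming `bytes` holds a complete Arrow IPC
    // stream (as the Snowflake chunk downloads above do):
    //
    // use futures::StreamExt;
    // let mut batches = Self::bytes_to_batches_stream(bytes)?;
    // while let Some(batch) = batches.next().await {
    //     let batch = batch?; // each item is a Result<RecordBatch, ArrowError>
    // }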

    /// Useful for debugging; returns the raw query response
#[cfg(debug_assertions)]
pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
@@ -430,54 +598,33 @@ impl SnowflakeApi {
}

async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
match self.execute_query(sql).await? {
ParsedQueryResponse::Empty => Ok(RawQueryResult::Empty),
ParsedQueryResponse::Json { value, schema } => {
Ok(RawQueryResult::Json(JsonResult { value, schema }))
}
ParsedQueryResponse::Arrow {
base64,
chunks,
chunk_headers,
} => {
// Fetch all chunks in parallel
let mut chunk_bytes = try_join_all(chunks.iter().map(|chunk| {
self.connection.get_chunk(&chunk.url, &chunk_headers)
}))
.await?;

                // Prepend the base64-encoded rowset if present: it holds the
                // first rows of the result, ahead of the remote chunks
                if !base64.is_empty() {
                    log::debug!("Got base64 encoded response");
                    let bytes = Bytes::from(
                        base64::engine::general_purpose::STANDARD.decode(base64)?,
                    );
                    chunk_bytes.insert(0, bytes);
                }

                Ok(RawQueryResult::Bytes(chunk_bytes))
            }
        }
}
