Skip to content

Commit

Permalink
Stream Arrow RecordBatches using mpsc channel and tokio_streams (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmmease authored Aug 14, 2024
1 parent 3583ec5 commit ffc034a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 70 deletions.
1 change: 1 addition & 0 deletions packages/duckdb-server-rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/duckdb-server-rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ axum-streams = { version = "0.18", features=["arrow"] }
async-stream = "0.3"
deadpool = { version = "0.12", features = ["managed"] }
deadpool-r2d2 = {version = "0.4.1"}
tokio-stream = "0.1"

[dev-dependencies]
http-body-util = "0.1.0"
Expand Down
95 changes: 25 additions & 70 deletions packages/duckdb-server-rust/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@ use tracing::span::Record;
use futures::future::FutureExt;
use futures::stream::StreamExt;

use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use crate::interfaces::adapt_anyhow_error;

/// Async database interface; implementors must be `Send + Sync`.
///
/// Every method takes an owned SQL string. The implementations are not part
/// of this declaration, so the notes on each method describe the signature
/// only.
///
/// NOTE(review): this listing is a rendered diff with the +/- markers
/// stripped, which is why `stream_record_batch` shows two return-type lines.
/// The `Result<Box<dyn Stream<...>>, Error>` line looks like the post-commit
/// one — confirm against the repository.
#[async_trait]
pub trait Database: Send + Sync {
/// Takes `sql` and resolves to `Result<()>` (no payload on success).
async fn execute(&self, sql: String) -> Result<()>;
/// Takes `sql` and resolves to a byte buffer; presumably JSON-encoded given
/// the name — the encoder is not visible here, verify.
async fn get_json(&self, sql: String) -> Result<Vec<u8>>;
/// Takes `sql` and resolves to a byte buffer; presumably Arrow-encoded given
/// the name — the encoder is not visible here, verify.
async fn get_arrow(&self, sql: String) -> Result<Vec<u8>>;
// async fn stream_arrow(&self, sql: String) -> Result<impl IntoResponse>;
/// Takes `sql` and resolves to a stream whose items are `RecordBatch`
/// values (see the note on the trait about the two return-type lines).
async fn stream_record_batch(
&self,
sql: String,
) -> impl Stream<Item = arrow::array::RecordBatch>;// Result<BoxStream<'static, Result<RecordBatch, Error>>>;
) -> Result<Box<dyn Stream<Item = RecordBatch>>, Error>;
}

type DuckDBManager = deadpool_r2d2::Manager<DuckdbConnectionManager>;
Expand Down Expand Up @@ -108,79 +110,32 @@ impl Database for ConnectionPool {
Ok(buffer)
}

// async fn stream_arrow(&self, sql: String) -> Result<impl IntoResponse> {
// let conn = self.pool.get().await?;
// let stream = conn.interact(move |client| {
// let mut stmt = client.prepare(&sql)?;
// let arrow = stmt.query_arrow([])?;
// let schema = arrow.get_schema();

// let stream = async_stream::try_stream! {
// for batch in arrow {
// yield batch;
// }
// };
// Ok(StreamBodyAs::arrow_ipc_with_errors(schema, stream))
// }).await.map_err(adapt_anyhow_error)??;

// let response = stream.into_response();
// Ok(response)
// }

/// Streams the Arrow record batches produced by `sql` through a channel.
///
/// NOTE(review): this span is a rendered diff with the +/- markers stripped,
/// so removed and added lines are interleaved (two return-type lines, two
/// `for batch in arrow` loops, and the deleted commented-out drafts). The
/// summary below describes what appears to be the post-commit version —
/// confirm against the repository before relying on it.
///
/// Visible flow: take a connection from `self.pool`, create an `mpsc`
/// channel of capacity 100, spawn a tokio task that runs the query inside
/// `conn.interact` and pushes each batch with `tx.blocking_send`, then
/// return the receiving half wrapped in a boxed `ReceiverStream`.
///
/// Only a failure of `self.pool.get()` is propagated with `?`. Failures
/// inside the spawned task are at most logged through `tracing::error!`;
/// they are not returned to the caller.
async fn stream_record_batch(
&self,
sql: String,
) -> Result<BoxStream<'static, Result<RecordBatch, Error>>> {
) -> Result<Box<dyn Stream<Item = RecordBatch>>, Error> {
let conn = self.pool.get().await?;

let stream = std::pin::pin! { conn
.interact(move |conn| {
// Channel with capacity 100: `tx` moves into the spawned task below, `rx`
// backs the stream returned at the end of the function.
let (tx, rx) = mpsc::channel(100);

tokio::spawn(async move {
let result = conn.interact(move |conn| {
let mut stmt = conn.prepare(&sql)?;
let arrow = stmt.query_arrow([])?;
let schema = arrow.get_schema();

let stream = stream! {
for batch in arrow {
yield batch;

for batch in arrow {
// A failed send is logged and ends the loop; remaining batches are not sent.
if let Err(e) = tx.blocking_send(batch) {
tracing::error!("Error processing batch: {:?}", e);
break;
}
};

Ok(stream)
})
.await
.map_err(adapt_anyhow_error)?? };

Ok(stream)
}

// fn stream_record_batch(&self, sql: &str) -> impl Stream<Item = Result<RecordBatch>> {
// let pool = self.pool.clone();
// async_stream::try_stream! {
// let conn = pool.get()?;
// let mut stmt = conn.prepare(sql)?;
// let arrow = stmt.query_arrow([])?;

// for batch in arrow {
// yield batch;
// }
// }
// }

// fn stream_record_batch(
// &self,
// sql: &str,
// ) -> Pin<Box<dyn Stream<Item = Result<RecordBatch, anyhow::Error>> + Send>> {
// let pool = self.pool.clone();
// Box::pin(futures::stream::unfold(
// (pool, sql.to_string()),
// |(pool, sql)| async move {
// let conn = pool.get().ok()?;
// let mut stmt = conn.prepare(&sql).ok()?;
// let mut arrow = stmt.query_arrow([]).ok()?;

// let batch = arrow.next()?;
// Some((Ok(batch), (pool, sql)))
// },
// ))
// }
}
Ok::<_, Error>(())
}).await;

// NOTE(review): the removed code unwrapped this same call with `??`, which
// suggests `interact(..).await` yields a nested `Result`. If so, this check
// only sees the outer error, and an `Err` returned by the closure (from
// `prepare` / `query_arrow`) is dropped without being logged — verify.
if let Err(e) = result {
tracing::error!("Error in database interaction: {:?}", e);
}
});

Ok(Box::new(ReceiverStream::new(rx)))
}
}

0 comments on commit ffc034a

Please sign in to comment.