Skip to content

Commit 8f54419

Browse files
committed Sep 26, 2024
feat(table-providers): make sqlserver pool compile
1 parent a7a7063 commit 8f54419

File tree

5 files changed

+827
-122
lines changed

5 files changed

+827
-122
lines changed
 

‎src-tauri/Cargo.lock

+677-75
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎src-tauri/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ rand = "0.8.5"
3838
bytes = "1.7.1"
3939
arrow-schema = "52.2.0"
4040
indexmap = "2.4.0"
41+
table-providers = { version = "0.1.0", path = "table-providers" }
4142

4243
[features]
4344
# this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled.

‎src-tauri/table-providers/Cargo.lock

+21
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎src-tauri/table-providers/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ chrono = "0.4.38"
1313
datafusion = "41.0.0"
1414
datafusion-table-providers = "0.1.0"
1515
futures = "0.3.30"
16+
pin-project = "1.1.5"
1617
secrecy = "0.10.1"
1718
tiberius = { version = "0.12.3", features = ["chrono", "time", "tokio"] }
1819
tokio = { version = "1.40.0", features = ["net"] }

‎src-tauri/table-providers/src/sqlserver/pool.rs

+127-47
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
//! A connection pool for SQL Server
22
33
use core::fmt;
4+
use std::borrow::Cow;
5+
use std::task::{ready, Poll};
46
use std::{collections::HashMap, sync::Arc};
57

6-
use arrow::datatypes::{Field, Schema, SchemaRef};
8+
use arrow::array::RecordBatch;
9+
use arrow::datatypes::{Schema, SchemaRef};
710

811
use async_stream::stream;
912
use async_trait::async_trait;
13+
use datafusion::execution::RecordBatchStream;
1014
use datafusion::sql::TableReference;
11-
use datafusion::{
12-
error::DataFusionError, execution::SendableRecordBatchStream,
13-
physical_plan::stream::RecordBatchStreamAdapter,
14-
};
15+
use datafusion::{error::DataFusionError, execution::SendableRecordBatchStream};
1516
use datafusion_table_providers::sql::db_connection_pool;
1617
use datafusion_table_providers::sql::db_connection_pool::{
1718
dbconnection::{AsyncDbConnection, DbConnection},
1819
DbConnectionPool, JoinPushDown,
1920
};
20-
use futures::{stream, StreamExt};
21+
use futures::{Stream, StreamExt};
22+
use pin_project::pin_project;
2123
use secrecy::{ExposeSecret, SecretString};
22-
use tiberius::{AuthMethod, Config, EncryptionLevel, ToSql, ColumnData};
24+
use tiberius::{AuthMethod, ColumnData, Config, EncryptionLevel, IntoSql, Query, ToSql};
2325
use tokio::net::TcpStream;
2426
use tokio_util::compat::TokioAsyncWriteCompatExt;
2527

@@ -84,6 +86,12 @@ impl fmt::Display for Error {
8486

8587
impl std::error::Error for Error {}
8688

89+
impl Into<DataFusionError> for Error {
90+
fn into(self) -> DataFusionError {
91+
DataFusionError::Execution(self.to_string())
92+
}
93+
}
94+
8795
pub struct SqlServerConnectionManager {
8896
config: tiberius::Config,
8997
}
@@ -210,11 +218,8 @@ impl DbConnectionPool<SqlServerPooledConnection, &'static dyn ToSql> for SqlServ
210218
Box<dyn DbConnection<SqlServerPooledConnection, &'static dyn ToSql>>,
211219
db_connection_pool::Error,
212220
> {
213-
let pool = Arc::clone(&self.pool);
214-
// let conn = pool.get().await?;
215-
// Ok(Box::new(SqlServerConnection::new(conn)))
216-
217-
todo!();
221+
let conn = self.pool.get_owned().await?;
222+
Ok(Box::new(SqlServerConnection::new(conn)))
218223
}
219224

220225
fn join_push_down(&self) -> JoinPushDown {
@@ -225,11 +230,97 @@ impl DbConnectionPool<SqlServerPooledConnection, &'static dyn ToSql> for SqlServ
225230
type ConnectionError = datafusion_table_providers::sql::db_connection_pool::dbconnection::Error;
226231
type GenericError = datafusion_table_providers::sql::db_connection_pool::dbconnection::GenericError;
227232

228-
struct OwnedColumnData<'a>(ColumnData<'a>);
233+
#[pin_project]
234+
struct SqlRecordBatchStream<S> {
235+
schema: Option<SchemaRef>,
236+
237+
#[pin]
238+
stream: S,
239+
}
240+
241+
impl<S> SqlRecordBatchStream<S> {
242+
fn new(stream: S) -> Self {
243+
Self {
244+
schema: None,
245+
stream,
246+
}
247+
}
248+
}
249+
250+
impl<S, E> Stream for SqlRecordBatchStream<S>
251+
where
252+
S: Stream<Item = Result<RecordBatch, E>>,
253+
E: Into<DataFusionError>,
254+
{
255+
type Item = Result<RecordBatch, DataFusionError>;
256+
257+
fn poll_next(
258+
self: std::pin::Pin<&mut Self>,
259+
cx: &mut std::task::Context<'_>,
260+
) -> Poll<Option<Self::Item>> {
261+
let this = self.project();
262+
263+
let batch = ready!(this.stream.poll_next(cx));
264+
let Some(batch) = batch else {
265+
return Poll::Ready(None);
266+
};
267+
268+
if let Ok(batch) = &batch {
269+
*this.schema = Some(batch.schema());
270+
}
271+
272+
Poll::Ready(Some(batch.map_err(Into::into)))
273+
}
274+
}
275+
276+
impl<S, E> RecordBatchStream for SqlRecordBatchStream<S>
277+
where
278+
S: Stream<Item = Result<RecordBatch, E>>,
279+
E: Into<DataFusionError>,
280+
{
281+
fn schema(&self) -> SchemaRef {
282+
self.schema.clone().unwrap_or(Arc::new(Schema::empty()))
283+
}
284+
}
285+
286+
fn to_owned<'a, T: ?Sized + ToOwned>(val: Cow<'a, T>) -> Cow<'static, T> {
287+
match val {
288+
Cow::Borrowed(val) => Cow::Owned(val.to_owned()),
289+
Cow::Owned(val) => Cow::Owned(val),
290+
}
291+
}
229292

230-
impl<'a> ToSql for OwnedColumnData<'a> {
231-
fn to_sql(&self) -> ColumnData<'_> {
232-
self.0.clone()
293+
/// A [`ColumnData`] that owns the underlying data, meaning that it will
294+
/// transform Cow::Borrowed data to Cow::Owned when needed
295+
struct OwnedColumnData(ColumnData<'static>);
296+
impl<'a> From<ColumnData<'a>> for OwnedColumnData {
297+
fn from(value: ColumnData<'a>) -> Self {
298+
Self(match value {
299+
ColumnData::U8(val) => ColumnData::U8(val),
300+
ColumnData::I16(val) => ColumnData::I16(val),
301+
ColumnData::I32(val) => ColumnData::I32(val),
302+
ColumnData::I64(val) => ColumnData::I64(val),
303+
ColumnData::F32(val) => ColumnData::F32(val),
304+
ColumnData::F64(val) => ColumnData::F64(val),
305+
ColumnData::Bit(val) => ColumnData::Bit(val),
306+
ColumnData::Guid(val) => ColumnData::Guid(val),
307+
ColumnData::Numeric(val) => ColumnData::Numeric(val),
308+
ColumnData::DateTime(val) => ColumnData::DateTime(val),
309+
ColumnData::SmallDateTime(val) => ColumnData::SmallDateTime(val),
310+
ColumnData::Time(val) => ColumnData::Time(val),
311+
ColumnData::Date(val) => ColumnData::Date(val),
312+
ColumnData::DateTime2(val) => ColumnData::DateTime2(val),
313+
ColumnData::DateTimeOffset(val) => ColumnData::DateTimeOffset(val),
314+
ColumnData::String(val) => ColumnData::String(val.map(to_owned)),
315+
ColumnData::Binary(val) => ColumnData::Binary(val.map(to_owned)),
316+
ColumnData::Xml(val) => ColumnData::Xml(val.map(to_owned)),
317+
})
318+
}
319+
}
320+
321+
impl IntoSql<'static> for OwnedColumnData {
322+
fn into_sql(self) -> ColumnData<'static> {
323+
self.0
233324
}
234325
}
235326

@@ -267,19 +358,25 @@ impl<'a> AsyncDbConnection<SqlServerPooledConnection, &'a dyn ToSql> for SqlServ
267358
&self,
268359
sql: &str,
269360
params: &[&'a dyn ToSql],
270-
projected_schema: Option<SchemaRef>,
361+
_projected_schema: Option<SchemaRef>,
271362
) -> Result<SendableRecordBatchStream, GenericError> {
272363
let conn = Arc::clone(&self.conn);
273-
let params = params.iter().map(|p| OwnedColumnData(p.to_sql())).collect::<Vec<_>>();
274364

275-
let mut stream = Box::pin(stream! {
276-
let mut conn = conn.lock().await;
365+
let sql = sql.to_string();
366+
let params = params
367+
.iter()
368+
.map(|p| OwnedColumnData::from(p.to_sql()))
369+
.collect::<Vec<_>>();
277370

278-
let stream = conn
279-
.query(sql, &params)
280-
.await?;
371+
let stream = stream! {
372+
let mut conn = conn.lock().await;
373+
let mut query = Query::new(sql.to_string());
374+
for param in params {
375+
query.bind(param);
376+
}
377+
let stream = query.query(&mut conn).await.map_err(Error::Sql)?;
281378

282-
let mut chunks = stream.chunks(8192).boxed().map(|rows| {
379+
let mut chunks = stream.chunks(8192).map(|rows| {
283380
let rows = rows.into_iter().collect::<Result<Vec<_>, _>>()?;
284381
let rec = super::arrow::rows_to_arrow(rows)?;
285382

@@ -289,32 +386,15 @@ impl<'a> AsyncDbConnection<SqlServerPooledConnection, &'a dyn ToSql> for SqlServ
289386
while let Some(chunk) = chunks.next().await {
290387
yield chunk
291388
}
292-
});
293-
294-
let Some(first_chunk) = stream.next().await else {
295-
return Ok(Box::pin(RecordBatchStreamAdapter::new(
296-
Arc::new(Schema::empty()),
297-
stream::empty(),
298-
)));
299-
};
300-
301-
let first_chunk =
302-
first_chunk.map_err(|e| ConnectionError::UnableToQueryArrow { source: e.into() })?;
303-
let schema = first_chunk.schema();
304-
305-
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, {
306-
stream! {
307-
yield Ok(first_chunk);
308-
309-
while let Some(chunk) = stream.next().await {
310-
yield chunk.map_err(|e| DataFusionError::Execution(format!("failed to fetch batch: {e}")))
311-
}
389+
}
390+
.boxed();
312391

313-
}
314-
})))
392+
Ok(Box::pin(SqlRecordBatchStream::new(stream)))
315393
}
316394

317395
async fn execute(&self, sql: &str, params: &[&'a dyn ToSql]) -> Result<u64, GenericError> {
318-
todo!()
396+
let mut conn = self.conn.lock().await;
397+
let result = conn.execute(sql, params).await?;
398+
Ok(result.into_iter().sum())
319399
}
320400
}

0 commit comments

Comments (0)
Please sign in to comment.