Skip to content

Commit 324c573

Browse files
committed
refactor(table-providers): split pool into multiple files
1 parent 95d13a1 commit 324c573

File tree

6 files changed

+234
-207
lines changed

6 files changed

+234
-207
lines changed

src-tauri/table-providers/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
mod sqlserver;
1+
pub mod sqlserver;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
//! Module that provides an [`AsyncDbConnection`] for SQL Server
2+
use std::{borrow::Cow, sync::Arc};
3+
4+
use arrow::datatypes::SchemaRef;
5+
use async_stream::stream;
6+
use async_trait::async_trait;
7+
use datafusion::{execution::SendableRecordBatchStream, sql::TableReference};
8+
use datafusion_table_providers::sql::db_connection_pool::dbconnection::{
9+
AsyncDbConnection, DbConnection,
10+
};
11+
use futures::StreamExt;
12+
use tiberius::{ColumnData, IntoSql, Query, ToSql};
13+
14+
use super::{pool::SqlServerPooledConnection, stream::SqlRecordBatchStream, Error};
15+
16+
pub struct SqlServerConnection {
17+
conn: Arc<tokio::sync::Mutex<SqlServerPooledConnection>>,
18+
}
19+
20+
type ConnectionError = datafusion_table_providers::sql::db_connection_pool::dbconnection::Error;
21+
type GenericError = datafusion_table_providers::sql::db_connection_pool::dbconnection::GenericError;
22+
23+
impl<'a> DbConnection<SqlServerPooledConnection, &'a dyn ToSql> for SqlServerConnection {
24+
fn as_any(&self) -> &dyn std::any::Any {
25+
self
26+
}
27+
28+
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
29+
self
30+
}
31+
32+
fn as_async(&self) -> Option<&dyn AsyncDbConnection<SqlServerPooledConnection, &'a dyn ToSql>> {
33+
Some(self)
34+
}
35+
}
36+
37+
fn to_owned<'a, T: ?Sized + ToOwned>(val: Cow<'a, T>) -> Cow<'static, T> {
38+
match val {
39+
Cow::Borrowed(val) => Cow::Owned(val.to_owned()),
40+
Cow::Owned(val) => Cow::Owned(val),
41+
}
42+
}
43+
44+
/// A [`ColumnData`] that owns the underlying data, meaning that it will
45+
/// transform Cow::Borrowed data to Cow::Owned when needed
46+
struct OwnedColumnData(ColumnData<'static>);
47+
impl<'a> From<ColumnData<'a>> for OwnedColumnData {
48+
fn from(value: ColumnData<'a>) -> Self {
49+
Self(match value {
50+
ColumnData::U8(val) => ColumnData::U8(val),
51+
ColumnData::I16(val) => ColumnData::I16(val),
52+
ColumnData::I32(val) => ColumnData::I32(val),
53+
ColumnData::I64(val) => ColumnData::I64(val),
54+
ColumnData::F32(val) => ColumnData::F32(val),
55+
ColumnData::F64(val) => ColumnData::F64(val),
56+
ColumnData::Bit(val) => ColumnData::Bit(val),
57+
ColumnData::Guid(val) => ColumnData::Guid(val),
58+
ColumnData::Numeric(val) => ColumnData::Numeric(val),
59+
ColumnData::DateTime(val) => ColumnData::DateTime(val),
60+
ColumnData::SmallDateTime(val) => ColumnData::SmallDateTime(val),
61+
ColumnData::Time(val) => ColumnData::Time(val),
62+
ColumnData::Date(val) => ColumnData::Date(val),
63+
ColumnData::DateTime2(val) => ColumnData::DateTime2(val),
64+
ColumnData::DateTimeOffset(val) => ColumnData::DateTimeOffset(val),
65+
ColumnData::String(val) => ColumnData::String(val.map(to_owned)),
66+
ColumnData::Binary(val) => ColumnData::Binary(val.map(to_owned)),
67+
ColumnData::Xml(val) => ColumnData::Xml(val.map(to_owned)),
68+
})
69+
}
70+
}
71+
72+
impl IntoSql<'static> for OwnedColumnData {
73+
fn into_sql(self) -> ColumnData<'static> {
74+
self.0
75+
}
76+
}
77+
78+
#[async_trait]
79+
impl<'a> AsyncDbConnection<SqlServerPooledConnection, &'a dyn ToSql> for SqlServerConnection {
80+
fn new(conn: SqlServerPooledConnection) -> Self
81+
where
82+
Self: Sized,
83+
{
84+
Self {
85+
conn: Arc::new(tokio::sync::Mutex::new(conn)),
86+
}
87+
}
88+
89+
async fn get_schema(
90+
&self,
91+
table_reference: &TableReference,
92+
) -> Result<SchemaRef, ConnectionError> {
93+
let table_ref = table_reference.to_quoted_string();
94+
let mut conn = self.conn.lock().await;
95+
96+
let stream = conn
97+
.query(format!("select * from {table_ref} limit 1"), &[])
98+
.await
99+
.map_err(|e| ConnectionError::UnableToGetSchema { source: e.into() })?;
100+
101+
let record = super::arrow::stream_to_arrow(stream)
102+
.await
103+
.map_err(|e| ConnectionError::UnableToGetSchema { source: e.into() })?;
104+
105+
Ok(record.schema())
106+
}
107+
108+
async fn query_arrow(
109+
&self,
110+
sql: &str,
111+
params: &[&'a dyn ToSql],
112+
_projected_schema: Option<SchemaRef>,
113+
) -> Result<SendableRecordBatchStream, GenericError> {
114+
let conn = Arc::clone(&self.conn);
115+
116+
let sql = sql.to_string();
117+
let params = params
118+
.iter()
119+
.map(|p| OwnedColumnData::from(p.to_sql()))
120+
.collect::<Vec<_>>();
121+
122+
let stream = stream! {
123+
let mut conn = conn.lock().await;
124+
let mut query = Query::new(sql.to_string());
125+
for param in params {
126+
query.bind(param);
127+
}
128+
let stream = query.query(&mut conn).await.map_err(Error::Sql)?;
129+
130+
let mut chunks = stream.chunks(8192).map(|rows| {
131+
let rows = rows.into_iter().collect::<Result<Vec<_>, _>>()?;
132+
let rec = super::arrow::rows_to_arrow(rows)?;
133+
134+
Ok::<_, Error>(rec)
135+
});
136+
137+
while let Some(chunk) = chunks.next().await {
138+
yield chunk
139+
}
140+
}
141+
.boxed();
142+
143+
Ok(Box::pin(SqlRecordBatchStream::new(stream)))
144+
}
145+
146+
async fn execute(&self, sql: &str, params: &[&'a dyn ToSql]) -> Result<u64, GenericError> {
147+
let mut conn = self.conn.lock().await;
148+
let result = conn.execute(sql, params).await?;
149+
Ok(result.into_iter().sum())
150+
}
151+
}

src-tauri/table-providers/src/sqlserver/connection_string.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ use connection_string::AdoNetString;
77
use secrecy::SecretString;
88

99
#[derive(Debug, Clone, PartialEq, Eq)]
10-
pub(crate) enum SqlServerConnectionStringError {
10+
pub enum SqlServerConnectionStringError {
1111
Invalid,
1212
MissingHost,
1313
InvalidPort(ParseIntError),
1414
}
1515

1616
impl From<connection_string::Error> for SqlServerConnectionStringError {
17-
fn from(value: connection_string::Error) -> Self {
17+
fn from(_value: connection_string::Error) -> Self {
1818
Self::Invalid
1919
}
2020
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
pub mod arrow;
22
mod connection_string;
33

4+
pub mod conn;
5+
pub use conn::SqlServerConnection;
6+
47
pub mod pool;
58
pub use pool::{Error, Result, SqlServerConnectionPool};
9+
10+
mod stream;

0 commit comments

Comments
 (0)