//! A connection pool for SQL Server

use core::fmt;
+ use std::borrow::Cow;
+ use std::task::{ready, Poll};
use std::{collections::HashMap, sync::Arc};

- use arrow::datatypes::{Field, Schema, SchemaRef};
+ use arrow::array::RecordBatch;
+ use arrow::datatypes::{Schema, SchemaRef};

use async_stream::stream;
use async_trait::async_trait;
+ use datafusion::execution::RecordBatchStream;
use datafusion::sql::TableReference;
- use datafusion::{
-     error::DataFusionError, execution::SendableRecordBatchStream,
-     physical_plan::stream::RecordBatchStreamAdapter,
- };
+ use datafusion::{error::DataFusionError, execution::SendableRecordBatchStream};
use datafusion_table_providers::sql::db_connection_pool;
use datafusion_table_providers::sql::db_connection_pool::{
    dbconnection::{AsyncDbConnection, DbConnection},
    DbConnectionPool, JoinPushDown,
};
- use futures::{stream, StreamExt};
+ use futures::{Stream, StreamExt};
+ use pin_project::pin_project;
use secrecy::{ExposeSecret, SecretString};
- use tiberius::{AuthMethod, Config, EncryptionLevel, ToSql, ColumnData};
+ use tiberius::{AuthMethod, ColumnData, Config, EncryptionLevel, IntoSql, Query, ToSql};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;

@@ -84,6 +86,12 @@ impl fmt::Display for Error {

impl std::error::Error for Error {}

+ impl Into<DataFusionError> for Error {
+     fn into(self) -> DataFusionError {
+         DataFusionError::Execution(self.to_string())
+     }
+ }
+
pub struct SqlServerConnectionManager {
    config: tiberius::Config,
}
@@ -210,11 +218,8 @@ impl DbConnectionPool<SqlServerPooledConnection, &'static dyn ToSql> for SqlServ
        Box<dyn DbConnection<SqlServerPooledConnection, &'static dyn ToSql>>,
        db_connection_pool::Error,
    > {
-         let pool = Arc::clone(&self.pool);
-         // let conn = pool.get().await?;
-         // Ok(Box::new(SqlServerConnection::new(conn)))
-
-         todo!();
+         let conn = self.pool.get_owned().await?;
+         Ok(Box::new(SqlServerConnection::new(conn)))
    }

    fn join_push_down(&self) -> JoinPushDown {
@@ -225,11 +230,97 @@ impl DbConnectionPool<SqlServerPooledConnection, &'static dyn ToSql> for SqlServ
type ConnectionError = datafusion_table_providers::sql::db_connection_pool::dbconnection::Error;
type GenericError = datafusion_table_providers::sql::db_connection_pool::dbconnection::GenericError;

- struct OwnedColumnData<'a>(ColumnData<'a>);
+ #[pin_project]
+ struct SqlRecordBatchStream<S> {
+     schema: Option<SchemaRef>,
+
+     #[pin]
+     stream: S,
+ }
+
+ impl<S> SqlRecordBatchStream<S> {
+     fn new(stream: S) -> Self {
+         Self {
+             schema: None,
+             stream,
+         }
+     }
+ }
+
+ impl<S, E> Stream for SqlRecordBatchStream<S>
+ where
+     S: Stream<Item = Result<RecordBatch, E>>,
+     E: Into<DataFusionError>,
+ {
+     type Item = Result<RecordBatch, DataFusionError>;
+
+     fn poll_next(
+         self: std::pin::Pin<&mut Self>,
+         cx: &mut std::task::Context<'_>,
+     ) -> Poll<Option<Self::Item>> {
+         let this = self.project();
+
+         let batch = ready!(this.stream.poll_next(cx));
+         let Some(batch) = batch else {
+             return Poll::Ready(None);
+         };
+
+         if let Ok(batch) = &batch {
+             *this.schema = Some(batch.schema());
+         }
+
+         Poll::Ready(Some(batch.map_err(Into::into)))
+     }
+ }
+
+ impl<S, E> RecordBatchStream for SqlRecordBatchStream<S>
+ where
+     S: Stream<Item = Result<RecordBatch, E>>,
+     E: Into<DataFusionError>,
+ {
+     fn schema(&self) -> SchemaRef {
+         self.schema.clone().unwrap_or(Arc::new(Schema::empty()))
+     }
+ }
+
+ fn to_owned<'a, T: ?Sized + ToOwned>(val: Cow<'a, T>) -> Cow<'static, T> {
+     match val {
+         Cow::Borrowed(val) => Cow::Owned(val.to_owned()),
+         Cow::Owned(val) => Cow::Owned(val),
+     }
+ }

- impl<'a> ToSql for OwnedColumnData<'a> {
-     fn to_sql(&self) -> ColumnData<'_> {
-         self.0.clone()
+ /// A [`ColumnData`] that owns the underlying data, meaning that it will
+ /// transform Cow::Borrowed data to Cow::Owned when needed
+ struct OwnedColumnData(ColumnData<'static>);
+ impl<'a> From<ColumnData<'a>> for OwnedColumnData {
+     fn from(value: ColumnData<'a>) -> Self {
+         Self(match value {
+             ColumnData::U8(val) => ColumnData::U8(val),
+             ColumnData::I16(val) => ColumnData::I16(val),
+             ColumnData::I32(val) => ColumnData::I32(val),
+             ColumnData::I64(val) => ColumnData::I64(val),
+             ColumnData::F32(val) => ColumnData::F32(val),
+             ColumnData::F64(val) => ColumnData::F64(val),
+             ColumnData::Bit(val) => ColumnData::Bit(val),
+             ColumnData::Guid(val) => ColumnData::Guid(val),
+             ColumnData::Numeric(val) => ColumnData::Numeric(val),
+             ColumnData::DateTime(val) => ColumnData::DateTime(val),
+             ColumnData::SmallDateTime(val) => ColumnData::SmallDateTime(val),
+             ColumnData::Time(val) => ColumnData::Time(val),
+             ColumnData::Date(val) => ColumnData::Date(val),
+             ColumnData::DateTime2(val) => ColumnData::DateTime2(val),
+             ColumnData::DateTimeOffset(val) => ColumnData::DateTimeOffset(val),
+             ColumnData::String(val) => ColumnData::String(val.map(to_owned)),
+             ColumnData::Binary(val) => ColumnData::Binary(val.map(to_owned)),
+             ColumnData::Xml(val) => ColumnData::Xml(val.map(to_owned)),
+         })
+     }
+ }
+
+ impl IntoSql<'static> for OwnedColumnData {
+     fn into_sql(self) -> ColumnData<'static> {
+         self.0
    }
}

@@ -267,19 +358,25 @@ impl<'a> AsyncDbConnection<SqlServerPooledConnection, &'a dyn ToSql> for SqlServ
        &self,
        sql: &str,
        params: &[&'a dyn ToSql],
-         projected_schema: Option<SchemaRef>,
+         _projected_schema: Option<SchemaRef>,
    ) -> Result<SendableRecordBatchStream, GenericError> {
        let conn = Arc::clone(&self.conn);
-         let params = params.iter().map(|p| OwnedColumnData(p.to_sql())).collect::<Vec<_>>();

-         let mut stream = Box::pin(stream! {
-             let mut conn = conn.lock().await;
+         let sql = sql.to_string();
+         let params = params
+             .iter()
+             .map(|p| OwnedColumnData::from(p.to_sql()))
+             .collect::<Vec<_>>();

-         let stream = conn
-             .query(sql, &params)
-             .await?;
+         let stream = stream! {
+             let mut conn = conn.lock().await;
+             let mut query = Query::new(sql.to_string());
+             for param in params {
+                 query.bind(param);
+             }
+             let stream = query.query(&mut conn).await.map_err(Error::Sql)?;

-             let mut chunks = stream.chunks(8192).boxed().map(|rows| {
+             let mut chunks = stream.chunks(8192).map(|rows| {
                let rows = rows.into_iter().collect::<Result<Vec<_>, _>>()?;
                let rec = super::arrow::rows_to_arrow(rows)?;

@@ -289,32 +386,15 @@ impl<'a> AsyncDbConnection<SqlServerPooledConnection, &'a dyn ToSql> for SqlServ
            while let Some(chunk) = chunks.next().await {
                yield chunk
            }
-         });
-
-         let Some(first_chunk) = stream.next().await else {
-             return Ok(Box::pin(RecordBatchStreamAdapter::new(
-                 Arc::new(Schema::empty()),
-                 stream::empty(),
-             )));
-         };
-
-         let first_chunk =
-             first_chunk.map_err(|e| ConnectionError::UnableToQueryArrow { source: e.into() })?;
-         let schema = first_chunk.schema();
-
-         Ok(Box::pin(RecordBatchStreamAdapter::new(schema, {
-             stream! {
-                 yield Ok(first_chunk);
-
-                 while let Some(chunk) = stream.next().await {
-                     yield chunk.map_err(|e| DataFusionError::Execution(format!("failed to fetch batch: {e}")))
-                 }
+         }
+         .boxed();

-             }
-         })))
+         Ok(Box::pin(SqlRecordBatchStream::new(stream)))
    }

    async fn execute(&self, sql: &str, params: &[&'a dyn ToSql]) -> Result<u64, GenericError> {
-         todo!()
+         let mut conn = self.conn.lock().await;
+         let result = conn.execute(sql, params).await?;
+         Ok(result.into_iter().sum())
    }
}
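
For reference, here is a minimal, standalone sketch of the Cow-widening idea behind the new to_owned helper and OwnedColumnData above. The widen name and the main demo are illustrative only, not part of this change:

use std::borrow::Cow;

// Widen any Cow<'_, T> to Cow<'static, T> by cloning borrowed data into an
// owned value; already-owned data is passed through untouched.
fn widen<T: ?Sized + ToOwned>(val: Cow<'_, T>) -> Cow<'static, T> {
    match val {
        Cow::Borrowed(v) => Cow::Owned(v.to_owned()),
        Cow::Owned(v) => Cow::Owned(v),
    }
}

fn main() {
    // A borrowed string is tied to the lifetime of the data it points at...
    let borrowed: Cow<'_, str> = Cow::Borrowed("select 1");
    // ...but after widening it owns its data and satisfies a 'static bound,
    // which is what lets parameters be moved into the async stream.
    let owned: Cow<'static, str> = widen(borrowed);
    assert_eq!(owned, "select 1");
}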