diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index e6d7eb92..42f70769 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -1,13 +1,13 @@ use super::{ - vector::{FlatVector, ListVector, Vector}, - BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, VTab, + vector::{FlatVector, ListVector, UnionVector, Vector}, + BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab, }; use crate::vtab::vector::Inserter; use arrow::array::{ - as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData, - BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, - StructArray, + as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, as_struct_array, + as_union_array, Array, ArrayData, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, + OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, UnionArray, }; use arrow::{ @@ -172,15 +172,18 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result Result { fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i)); } - // DataType::Struct(_) => { - // let struct_array = as_struct_array(col.as_ref()); - // let mut struct_vector = chunk.struct_vector(i); - // struct_array_to_vector(struct_array, &mut struct_vector); - // } + DataType::Struct(_) => { + let struct_array = as_struct_array(col.as_ref()); + let mut struct_vector = chunk.struct_vector(i); + struct_array_to_vector(struct_array, &mut struct_vector); + } + DataType::Union(fields, mode) => { + assert_eq!( + mode, + &UnionMode::Sparse, + "duckdb only supports Sparse array for union types" + ); + let union_array = as_union_array(col.as_ref()); + let mut union_vector = chunk.union_vector(i); + union_array_to_vector(fields, union_array, &mut union_vector); + } _ => { - println!( + unimplemented!( 
"column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs", batch.schema().field(i) ); - todo!() } } } @@ -406,46 +419,66 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray { arr.as_any().downcast_ref::().unwrap() } -// fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) { -// for i in 0..array.num_columns() { -// let column = array.column(i); -// match column.data_type() { -// dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { -// primitive_array_to_vector(column, &mut out.child(i)); -// } -// DataType::Utf8 => { -// string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i)); -// } -// DataType::List(_) => { -// list_array_to_vector( -// as_list_array(column.as_ref()), -// &mut out.list_vector_child(i), -// ); -// } -// DataType::LargeList(_) => { -// list_array_to_vector( -// as_large_list_array(column.as_ref()), -// &mut out.list_vector_child(i), -// ); -// } -// DataType::FixedSizeList(_, _) => { -// fixed_size_list_array_to_vector( -// as_fixed_size_list_array(column.as_ref()), -// &mut out.list_vector_child(i), -// ); -// } -// DataType::Struct(_) => { -// let struct_array = as_struct_array(column.as_ref()); -// let mut struct_vector = out.struct_vector_child(i); -// struct_array_to_vector(struct_array, &mut struct_vector); -// } -// _ => { -// println!("Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs", column.data_type()); -// todo!() -// } -// } -// } -// } +fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) { + for i in 0..array.num_columns() { + let column = array.column(i); + match column.data_type() { + dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { + primitive_array_to_vector(column, &mut out.child(i)); + } + DataType::Utf8 => { + string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i)); + } + DataType::List(_) => { + 
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i)); + } + DataType::LargeList(_) => { + list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i)); + } + DataType::FixedSizeList(_, _) => { + fixed_size_list_array_to_vector( + as_fixed_size_list_array(column.as_ref()), + &mut out.list_vector_child(i), + ); + } + DataType::Struct(_) => { + let struct_array = as_struct_array(column.as_ref()); + let mut struct_vector = out.struct_vector_child(i); + struct_array_to_vector(struct_array, &mut struct_vector); + } + _ => { + unimplemented!( + "Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs", + column.data_type() + ); + } + } + } +} + +fn union_array_to_vector(union_fields: &UnionFields, array: &UnionArray, out: &mut UnionVector) { + // set the tag array + let out_tag_array = &mut out.tag_vector(); + out_tag_array.copy(array.type_ids()); + + for (i, (type_id, _)) in union_fields.iter().enumerate() { + let column = array.child(type_id); + match column.data_type() { + dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { + primitive_array_to_vector(column, &mut out.member_vector(i)); + } + DataType::Utf8 => { + string_array_to_vector(as_string_array(column.as_ref()), &mut out.member_vector(i)); + } + _ => { + unimplemented!( + "Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs", + column.data_type() + ); + } + } + } +} /// Pass RecordBatch to duckdb. 
///
@@ -485,8 +518,9 @@ mod test {
     use super::{arrow_recordbatch_to_query_params, ArrowVTab};
     use crate::{Connection, Result};
     use arrow::{
-        array::{Float64Array, Int32Array},
-        datatypes::{DataType, Field, Schema},
+        array::{Array, ArrayRef, Float64Array, Int32Array, StringArray, StructArray, UnionArray},
+        buffer::Buffer,
+        datatypes::{DataType, Field, Fields, Schema, UnionFields},
         record_batch::RecordBatch,
     };
     use std::{error::Error, sync::Arc};
@@ -531,4 +565,83 @@ mod test {
         assert_eq!(column.value(0), 15);
         Ok(())
     }
+
+    /// Appends a two-row `STRUCT(v VARCHAR, i INTEGER)` record batch and checks that two rows are read back.
+    #[test]
+    fn test_append_struct() -> Result<(), Box<dyn Error>> {
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?;
+        {
+            let struct_array = StructArray::from(vec![
+                (
+                    Arc::new(Field::new("v", DataType::Utf8, true)),
+                    Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef,
+                ),
+                (
+                    Arc::new(Field::new("i", DataType::Int32, true)),
+                    Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef,
+                ),
+            ]);
+
+            let schema = Schema::new(vec![Field::new(
+                "s",
+                DataType::Struct(Fields::from(vec![
+                    Field::new("v", DataType::Utf8, true),
+                    Field::new("i", DataType::Int32, true),
+                ])),
+                true,
+            )]);
+
+            let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
+            let mut app = db.appender("t1")?;
+            app.append_record_batch(record_batch)?;
+        }
+        let mut stmt = db.prepare("SELECT s FROM t1")?;
+        let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
+        assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 2);
+        Ok(())
+    }
+
+    /// Appends a three-row sparse `UNION(num INT, str VARCHAR)` record batch and checks that three rows are read back.
+    #[test]
+    fn test_append_union() -> Result<(), Box<dyn Error>> {
+        let db = Connection::open_in_memory()?;
+
+        db.execute_batch("CREATE TABLE tbl1 (u UNION(num INT, str VARCHAR));")?;
+
+        {
+            let int_array = Int32Array::from(vec![Some(1), None, Some(34)]);
+            let string_array = StringArray::from(vec![None, Some("foo"), None]);
+            let type_id_buffer = Buffer::from_slice_ref(&[0_i8, 1, 0]);
+
+            let children: Vec<(Field, Arc<dyn Array>)> = vec![
+                (Field::new("num", DataType::Int32, false), Arc::new(int_array)),
+                (Field::new("str", DataType::Utf8, true), Arc::new(string_array)),
+            ];
+
+            let union_array = UnionArray::try_new(&vec![0, 1], type_id_buffer, None, children).unwrap();
+
+            let union_fields = UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("num", DataType::Int32, false),
+                    Field::new("str", DataType::Utf8, true),
+                ],
+            );
+
+            let schema = Schema::new(vec![Field::new(
+                "u",
+                DataType::Union(union_fields, arrow::datatypes::UnionMode::Sparse),
+                true,
+            )]);
+            let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union_array)])?;
+            let mut app = db.appender("tbl1")?;
+            app.append_record_batch(record_batch)?;
+        }
+        let mut stmt = db.prepare("SELECT u FROM tbl1")?;
+        let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
+        // The appended batch holds three rows (type ids 0, 1, 0), so three rows must come back.
+        assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 3);
+        Ok(())
+    }
 }
diff --git a/src/vtab/data_chunk.rs b/src/vtab/data_chunk.rs
index 6e472773..092e7384 100644
--- a/src/vtab/data_chunk.rs
+++ b/src/vtab/data_chunk.rs
@@ -1,6 +1,6 @@
 use super::{
     logical_type::LogicalType,
-    vector::{FlatVector, ListVector, StructVector},
+    vector::{FlatVector, ListVector, StructVector, UnionVector},
 };
 use crate::ffi::{
     duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size,
@@ -40,6 +40,11 @@ impl DataChunk {
         StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
     }
 
+    /// Get union vector at the column index: `idx`.
+    pub fn union_vector(&self, idx: usize) -> UnionVector {
+        UnionVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
+    }
+
     /// Set the size of the data chunk
     pub fn set_len(&self, new_len: usize) {
         unsafe { duckdb_data_chunk_set_size(self.ptr, new_len as u64) };
diff --git a/src/vtab/vector.rs b/src/vtab/vector.rs
index bf61cff4..cdb06273 100644
--- a/src/vtab/vector.rs
+++ b/src/vtab/vector.rs
@@ -1,5 +1,7 @@
 use std::{any::Any, ffi::CString, slice};
 
+use libduckdb_sys::{duckdb_union_type_member_count, duckdb_union_type_member_name};
+
 use super::LogicalType;
 use crate::ffi::{
     duckdb_list_entry, duckdb_list_vector_get_child, duckdb_list_vector_get_size, duckdb_list_vector_reserve,
@@ -221,3 +223,61 @@ impl StructVector {
         unsafe { duckdb_struct_type_child_count(logical_type.ptr) as usize }
     }
 }
+
+impl From<duckdb_vector> for UnionVector {
+    fn from(ptr: duckdb_vector) -> Self {
+        Self { ptr }
+    }
+}
+
+/// A union vector
+pub struct UnionVector {
+    /// UnionVector does not own the vector pointer
+    ptr: duckdb_vector,
+}
+
+impl UnionVector {
+    /// Get the logical type of this union vector.
+    pub fn logical_type(&self) -> LogicalType {
+        LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })
+    }
+
+    /// Retrieves the tag vector of a union vector (child 0 of the underlying vector).
+    pub fn tag_vector(&self) -> FlatVector {
+        FlatVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, 0) })
+    }
+
+    /// Retrieves the member vector at `idx` of a union vector (child `idx + 1`, after the tag vector).
+    pub fn member_vector(&self, idx: usize) -> FlatVector {
+        FlatVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, (idx + 1) as u64) })
+    }
+
+    /// Retrieves the child type of the given union member at the specified index
+    pub fn member_child_type(&self, idx: usize) -> LogicalType {
+        self.logical_type().child(idx + 1)
+    }
+
+    /// Get the name of the union member at `idx`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if DuckDB returns a null pointer for the member name.
+    pub fn member_name(&self, idx: usize) -> String {
+        let logical_type = self.logical_type();
+        unsafe {
+            let member_name_ptr = duckdb_union_type_member_name(logical_type.ptr, idx as u64);
+            assert!(!member_name_ptr.is_null(), "DuckDB returned a null name for union member {}", idx);
+            // The buffer was not produced by `CString::into_raw`, so it must not be handed to
+            // `CString::from_raw`: copy the text out, then release the buffer with `duckdb_free`.
+            let name = std::ffi::CStr::from_ptr(member_name_ptr).to_string_lossy().into_owned();
+            crate::ffi::duckdb_free(member_name_ptr.cast());
+            name
+        }
+    }
+
+    /// Returns the number of members that the union type has.
+    pub fn num_union_members(&self) -> usize {
+        let logical_type = self.logical_type();
+        unsafe { duckdb_union_type_member_count(logical_type.ptr) as usize }
+    }
+}