Skip to content

Commit 5a1729e

Browse files
JeadieMaxxen
andauthored
Support UTF8 in nested Apache Arrow data types (e.g. List) (#300)
* support UTF8[] * add tests * fix test * format * clippy * bump cause github is broken --------- Co-authored-by: Max Gabrielsson <[email protected]>
1 parent 0018cd8 commit 5a1729e

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pretty_assertions = "1.4.0"
8484
path = "libduckdb-sys"
8585
version = "0.10.1"
8686

87+
8788
[package.metadata.docs.rs]
8889
features = ['vtab', 'chrono']
8990
all-features = false

src/vtab/arrow.rs

+78-8
Original file line numberDiff line numberDiff line change
@@ -422,17 +422,20 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
422422
match value_array.data_type() {
423423
dt if dt.is_primitive() => {
424424
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
425-
for i in 0..array.len() {
426-
let offset = array.value_offsets()[i];
427-
let length = array.value_length(i);
428-
out.set_entry(i, offset.as_(), length.as_());
429-
}
425+
}
426+
DataType::Utf8 => {
427+
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
430428
}
431429
_ => {
432430
return Err("Nested list is not supported yet.".into());
433431
}
434432
}
435433

434+
for i in 0..array.len() {
435+
let offset = array.value_offsets()[i];
436+
let length = array.value_length(i);
437+
out.set_entry(i, offset.as_(), length.as_());
438+
}
436439
Ok(())
437440
}
438441

@@ -452,10 +455,19 @@ fn fixed_size_list_array_to_vector(
452455
}
453456
out.set_len(value_array.len());
454457
}
458+
DataType::Utf8 => {
459+
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
460+
}
455461
_ => {
456462
return Err("Nested list is not supported yet.".into());
457463
}
458464
}
465+
for i in 0..array.len() {
466+
let offset = array.value_offset(i);
467+
let length = array.value_length();
468+
out.set_entry(i, offset as usize, length as usize);
469+
}
470+
out.set_len(value_array.len());
459471

460472
Ok(())
461473
}
@@ -543,10 +555,12 @@ mod test {
543555
use crate::{Connection, Result};
544556
use arrow::{
545557
array::{
546-
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, Int32Array,
547-
PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
548-
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
558+
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray,
559+
Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
560+
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
561+
TimestampSecondArray,
549562
},
563+
buffer::{OffsetBuffer, ScalarBuffer},
550564
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
551565
record_batch::RecordBatch,
552566
};
@@ -676,6 +690,62 @@ mod test {
676690
Ok(())
677691
}
678692

693+
fn check_generic_array_roundtrip<T>(arry: GenericListArray<T>) -> Result<(), Box<dyn Error>>
694+
where
695+
T: OffsetSizeTrait,
696+
{
697+
let expected_output_array = arry.clone();
698+
699+
let db = Connection::open_in_memory()?;
700+
db.register_table_function::<ArrowVTab>("arrow")?;
701+
702+
// Roundtrip a record batch from Rust to DuckDB and back to Rust
703+
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]);
704+
705+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?;
706+
let param = arrow_recordbatch_to_query_params(rb);
707+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
708+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
709+
710+
let output_any_array = rb.column(0);
711+
assert!(output_any_array
712+
.data_type()
713+
.equals_datatype(expected_output_array.data_type()));
714+
715+
match output_any_array.as_list_opt::<T>() {
716+
Some(output_array) => {
717+
assert_eq!(output_array.len(), expected_output_array.len());
718+
for i in 0..output_array.len() {
719+
assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i));
720+
if output_array.is_valid(i) {
721+
assert!(expected_output_array.value(i).eq(&output_array.value(i)));
722+
}
723+
}
724+
}
725+
None => panic!("Expected GenericListArray"),
726+
}
727+
728+
Ok(())
729+
}
730+
731+
#[test]
732+
fn test_array_roundtrip() -> Result<(), Box<dyn Error>> {
733+
check_generic_array_roundtrip(ListArray::new(
734+
Arc::new(Field::new("item", DataType::Utf8, true)),
735+
OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])),
736+
Arc::new(StringArray::from(vec![
737+
Some("foo"),
738+
Some("baz"),
739+
Some("bar"),
740+
Some("foo"),
741+
Some("baz"),
742+
])),
743+
None,
744+
))?;
745+
746+
Ok(())
747+
}
748+
679749
#[test]
680750
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
681751
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;

0 commit comments

Comments
 (0)