Skip to content

Commit 0018cd8

Browse files
authored
feat: enum read support (#297)
* add basic enum support
* add enum to test_all_types
1 parent d613139 commit 0018cd8

File tree

5 files changed

+77
-6
lines changed

5 files changed

+77
-6
lines changed

src/row.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use std::{convert, sync::Arc};
22

33
use super::{Error, Result, Statement};
4-
use crate::types::{self, FromSql, FromSqlError, ValueRef};
4+
use crate::types::{self, EnumType, FromSql, FromSqlError, ValueRef};
55

6+
use arrow::array::DictionaryArray;
67
use arrow::{
78
array::{self, Array, ArrayRef, ListArray, StructArray},
89
datatypes::*,
@@ -601,6 +602,24 @@ impl<'stmt> Row<'stmt> {
601602

602603
ValueRef::List(arr, row)
603604
}
605+
DataType::Dictionary(key_type, ..) => {
606+
let column = column.as_any();
607+
ValueRef::Enum(
608+
match key_type.as_ref() {
609+
DataType::UInt8 => {
610+
EnumType::UInt8(column.downcast_ref::<DictionaryArray<UInt8Type>>().unwrap())
611+
}
612+
DataType::UInt16 => {
613+
EnumType::UInt16(column.downcast_ref::<DictionaryArray<UInt16Type>>().unwrap())
614+
}
615+
DataType::UInt32 => {
616+
EnumType::UInt32(column.downcast_ref::<DictionaryArray<UInt32Type>>().unwrap())
617+
}
618+
typ => panic!("Unsupported key type: {typ:?}"),
619+
},
620+
row,
621+
)
622+
}
604623
_ => unreachable!("invalid value: {} {}", col, column.data_type()),
605624
}
606625
}

src/test_all_types.rs

+15-3
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ fn test_all_types() -> crate::Result<()> {
1818
// union is currently blocked by https://github.com/duckdb/duckdb/pull/11326
1919
"union",
2020
// these remaining types are not yet supported by duckdb-rs
21-
"small_enum",
22-
"medium_enum",
23-
"large_enum",
2421
"struct",
2522
"struct_of_arrays",
2623
"array_of_structs",
@@ -349,6 +346,21 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
349346
),
350347
_ => assert_eq!(value, ValueRef::Null),
351348
},
349+
"small_enum" => match idx {
350+
0 => assert_eq!(value.to_owned(), Value::Enum("DUCK_DUCK_ENUM".to_string())),
351+
1 => assert_eq!(value.to_owned(), Value::Enum("GOOSE".to_string())),
352+
_ => assert_eq!(value, ValueRef::Null),
353+
},
354+
"medium_enum" => match idx {
355+
0 => assert_eq!(value.to_owned(), Value::Enum("enum_0".to_string())),
356+
1 => assert_eq!(value.to_owned(), Value::Enum("enum_1".to_string())),
357+
_ => assert_eq!(value, ValueRef::Null),
358+
},
359+
"large_enum" => match idx {
360+
0 => assert_eq!(value.to_owned(), Value::Enum("enum_0".to_string())),
361+
1 => assert_eq!(value.to_owned(), Value::Enum("enum_69999".to_string())),
362+
_ => assert_eq!(value, ValueRef::Null),
363+
},
352364
_ => todo!("{column:?}"),
353365
}
354366
}

src/types/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ pub use self::{
7171
from_sql::{FromSql, FromSqlError, FromSqlResult},
7272
to_sql::{ToSql, ToSqlOutput},
7373
value::Value,
74-
value_ref::{TimeUnit, ValueRef},
74+
value_ref::{EnumType, TimeUnit, ValueRef},
7575
};
7676

7777
use arrow::datatypes::DataType;
@@ -149,6 +149,8 @@ pub enum Type {
149149
Interval,
150150
/// LIST
151151
List(Box<Type>),
152+
/// ENUM
153+
Enum,
152154
/// Any
153155
Any,
154156
}
@@ -219,6 +221,7 @@ impl fmt::Display for Type {
219221
Type::Time64 => f.pad("Time64"),
220222
Type::Interval => f.pad("Interval"),
221223
Type::List(..) => f.pad("List"),
224+
Type::Enum => f.pad("Enum"),
222225
Type::Any => f.pad("Any"),
223226
}
224227
}

src/types/value.rs

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ pub enum Value {
5757
},
5858
/// The value is a list
5959
List(Vec<Value>),
60+
/// The value is an enum
61+
Enum(String),
6062
}
6163

6264
impl From<Null> for Value {
@@ -225,6 +227,7 @@ impl Value {
225227
Value::Time64(..) => Type::Time64,
226228
Value::Interval { .. } => Type::Interval,
227229
Value::List(_) => todo!(),
230+
Value::Enum(..) => Type::Enum,
228231
}
229232
}
230233
}

src/types/value_ref.rs

+35-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use crate::types::{FromSqlError, FromSqlResult};
44
use crate::Row;
55
use rust_decimal::prelude::*;
66

7-
use arrow::array::{Array, ListArray};
7+
use arrow::array::{Array, DictionaryArray, ListArray};
8+
use arrow::datatypes::{UInt16Type, UInt32Type, UInt8Type};
89

910
/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
1011
/// Copy from arrow::datatypes::TimeUnit
@@ -75,6 +76,19 @@ pub enum ValueRef<'a> {
7576
},
7677
/// The value is a list
7778
List(&'a ListArray, usize),
79+
/// The value is an enum
80+
Enum(EnumType<'a>, usize),
81+
}
82+
83+
/// Wrapper type for different enum sizes
84+
#[derive(Debug, Copy, Clone, PartialEq)]
85+
pub enum EnumType<'a> {
86+
/// The underlying enum type is u8
87+
UInt8(&'a DictionaryArray<UInt8Type>),
88+
/// The underlying enum type is u16
89+
UInt16(&'a DictionaryArray<UInt16Type>),
90+
/// The underlying enum type is u32
91+
UInt32(&'a DictionaryArray<UInt32Type>),
7892
}
7993

8094
impl ValueRef<'_> {
@@ -103,6 +117,7 @@ impl ValueRef<'_> {
103117
ValueRef::Time64(..) => Type::Time64,
104118
ValueRef::Interval { .. } => Type::Interval,
105119
ValueRef::List(arr, _) => arr.data_type().into(),
120+
ValueRef::Enum(..) => Type::Enum,
106121
}
107122
}
108123

@@ -170,6 +185,24 @@ impl From<ValueRef<'_>> for Value {
170185
.collect();
171186
Value::List(map)
172187
}
188+
ValueRef::Enum(items, idx) => {
189+
let value = Row::value_ref_internal(
190+
idx,
191+
0,
192+
match items {
193+
EnumType::UInt8(res) => res.values(),
194+
EnumType::UInt16(res) => res.values(),
195+
EnumType::UInt32(res) => res.values(),
196+
},
197+
)
198+
.to_owned();
199+
200+
if let Value::Text(s) = value {
201+
Value::Enum(s)
202+
} else {
203+
panic!("Enum value is not a string")
204+
}
205+
}
173206
}
174207
}
175208
}
@@ -213,6 +246,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
213246
Value::Time64(t, d) => ValueRef::Time64(t, d),
214247
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
215248
Value::List(..) => unimplemented!(),
249+
Value::Enum(..) => todo!(),
216250
}
217251
}
218252
}

0 commit comments

Comments
 (0)