Skip to content

Commit a198309

Browse files
committed
Support try_from_array and eq_array for ScalarValue::Union
1 parent bd50698 commit a198309

File tree

1 file changed

+122
-2
lines changed
  • datafusion/common/src/scalar

1 file changed

+122
-2
lines changed

datafusion/common/src/scalar/mod.rs

+122-2
Original file line numberDiff line numberDiff line change
@@ -2800,6 +2800,13 @@ impl ScalarValue {
28002800
let a = array.slice(index, 1);
28012801
Self::Map(Arc::new(a.as_map().to_owned()))
28022802
}
2803+
DataType::Union(fields, mode) => {
2804+
let array = as_union_array(array);
2805+
let ti = array.type_id(index);
2806+
let index = array.value_offset(index);
2807+
let value = ScalarValue::try_from_array(array.child(ti), index)?;
2808+
ScalarValue::Union(Some((ti, Box::new(value))), fields.clone(), *mode)
2809+
}
28032810
other => {
28042811
return _not_impl_err!(
28052812
"Can't create a scalar from array of type \"{other:?}\""
@@ -3035,8 +3042,15 @@ impl ScalarValue {
30353042
ScalarValue::DurationNanosecond(val) => {
30363043
eq_array_primitive!(array, index, DurationNanosecondArray, val)?
30373044
}
3038-
ScalarValue::Union(_, _, _) => {
3039-
return _not_impl_err!("Union is not supported yet")
3045+
ScalarValue::Union(value, _, _) => {
3046+
let array = as_union_array(array);
3047+
let ti = array.type_id(index);
3048+
let index = array.value_offset(index);
3049+
if let Some((ti_v, value)) = value {
3050+
ti_v == &ti && value.eq_array(array.child(ti), index)?
3051+
} else {
3052+
array.child(ti).is_null(index)
3053+
}
30403054
}
30413055
ScalarValue::Dictionary(key_type, v) => {
30423056
let (values_array, values_index) = match key_type.as_ref() {
@@ -5511,6 +5525,112 @@ mod tests {
55115525
assert_eq!(&array, &expected);
55125526
}
55135527

5528+
// Builds a six-row union array without an offsets buffer and checks, row by row, that
// `ScalarValue::try_from_array` yields the expected `ScalarValue::Union` (tagged
// `UnionMode::Sparse`) and that `eq_array` reports that scalar equal to its source row.
#[test]
5529+
fn test_scalar_union_sparse() {
5530+
// Three union fields (Int32, Boolean, Utf8) registered below under type ids 0, 1 and 2.
let field_a = Arc::new(Field::new("A", DataType::Int32, true));
5531+
let field_b = Arc::new(Field::new("B", DataType::Boolean, true));
5532+
let field_c = Arc::new(Field::new("C", DataType::Utf8, true));
5533+
let fields = UnionFields::from_iter([(0, field_a), (1, field_b), (2, field_c)]);
5534+
5535+
// Every child has six slots (one per union row); exactly one slot per child is
// populated and the remaining five are null.
let mut values_a = vec![None; 6];
5536+
values_a[0] = Some(42);
5537+
let mut values_b = vec![None; 6];
5538+
values_b[1] = Some(true);
5539+
let mut values_c = vec![None; 6];
5540+
values_c[2] = Some("foo");
5541+
let children: Vec<ArrayRef> = vec![
5542+
Arc::new(Int32Array::from(values_a)),
5543+
Arc::new(BooleanArray::from(values_b)),
5544+
Arc::new(StringArray::from(values_c)),
5545+
];
5546+
5547+
// Type id selected by each of the six rows: the three children are cycled twice.
let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]);
5548+
// `None` is passed for the offsets argument; the scalars expected below carry
// `UnionMode::Sparse`.
let array: ArrayRef = Arc::new(
5549+
UnionArray::try_new(fields.clone(), type_ids, None, children)
5550+
.expect("UnionArray"),
5551+
);
5552+
5553+
// (type id, child scalar) expected for rows 0..6, in row order: the first three rows
// hit the populated slots, the last three hit null slots of the same children.
let expected = [
5554+
(0, ScalarValue::from(42)),
5555+
(1, ScalarValue::from(true)),
5556+
(2, ScalarValue::from("foo")),
5557+
(0, ScalarValue::Int32(None)),
5558+
(1, ScalarValue::Boolean(None)),
5559+
(2, ScalarValue::Utf8(None)),
5560+
];
5561+
5562+
for (i, (ti, value)) in expected.into_iter().enumerate() {
5563+
// Record null-ness now; `value` is moved into the `Box` on the next line.
let is_null = value.is_null();
5564+
let value = Some((ti, Box::new(value)));
5565+
let expected = ScalarValue::Union(value, fields.clone(), UnionMode::Sparse);
5566+
let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array");
5567+
5568+
// Extracting row `i` must reproduce the expected union scalar exactly.
assert_eq!(
5569+
actual, expected,
5570+
"[{i}] {actual} was not equal to {expected}"
5571+
);
5572+
5573+
// The expected scalar must also compare equal to row `i` of the array.
assert!(
5574+
expected.eq_array(&array, i).expect("eq_array"),
5575+
"[{i}] {expected}.eq_array was false"
5576+
);
5577+
5578+
// When the wrapped child value is null, the extracted union scalar must report
// null as well.
if is_null {
5579+
assert!(actual.is_null(), "[{i}] {actual} was not null")
5580+
}
5581+
}
5582+
}
5583+
5584+
// Builds a six-row union array with an explicit offsets buffer and checks, row by row,
// that `ScalarValue::try_from_array` yields the expected `ScalarValue::Union` (tagged
// `UnionMode::Dense`) and that `eq_array` reports that scalar equal to its source row.
#[test]
5585+
fn test_scalar_union_dense() {
5586+
// Three union fields (Int32, Boolean, Utf8) registered below under type ids 0, 1 and 2.
let field_a = Arc::new(Field::new("A", DataType::Int32, true));
5587+
let field_b = Arc::new(Field::new("B", DataType::Boolean, true));
5588+
let field_c = Arc::new(Field::new("C", DataType::Utf8, true));
5589+
let fields = UnionFields::from_iter([(0, field_a), (1, field_b), (2, field_c)]);
5590+
// Each child holds only two slots: slot 0 is a value, slot 1 is null.
let children: Vec<ArrayRef> = vec![
5591+
Arc::new(Int32Array::from(vec![Some(42), None])),
5592+
Arc::new(BooleanArray::from(vec![Some(true), None])),
5593+
Arc::new(StringArray::from(vec![Some("foo"), None])),
5594+
];
5595+
5596+
// Type id selected by each of the six rows: the three children are cycled twice.
let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]);
5597+
// Child slot read by each row: rows 0-2 use slot 0 (the value), rows 3-5 use slot 1
// (the null).
let offsets = ScalarBuffer::from(vec![0, 0, 0, 1, 1, 1]);
5598+
let array: ArrayRef = Arc::new(
5599+
UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)
5600+
.expect("UnionArray"),
5601+
);
5602+
5603+
// (type id, child scalar) expected for rows 0..6, in row order.
let expected = [
5604+
(0, ScalarValue::from(42)),
5605+
(1, ScalarValue::from(true)),
5606+
(2, ScalarValue::from("foo")),
5607+
(0, ScalarValue::Int32(None)),
5608+
(1, ScalarValue::Boolean(None)),
5609+
(2, ScalarValue::Utf8(None)),
5610+
];
5611+
5612+
for (i, (ti, value)) in expected.into_iter().enumerate() {
5613+
// Record null-ness now; `value` is moved into the `Box` on the next line.
let is_null = value.is_null();
5614+
let value = Some((ti, Box::new(value)));
5615+
let expected = ScalarValue::Union(value, fields.clone(), UnionMode::Dense);
5616+
let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array");
5617+
5618+
// Extracting row `i` must reproduce the expected union scalar exactly.
assert_eq!(
5619+
actual, expected,
5620+
"[{i}] {actual} was not equal to {expected}"
5621+
);
5622+
5623+
// The expected scalar must also compare equal to row `i` of the array.
assert!(
5624+
expected.eq_array(&array, i).expect("eq_array"),
5625+
"[{i}] {expected}.eq_array was false"
5626+
);
5627+
5628+
// When the wrapped child value is null, the extracted union scalar must report
// null as well.
if is_null {
5629+
assert!(actual.is_null(), "[{i}] {actual} was not null")
5630+
}
5631+
}
5632+
}
5633+
55145634
#[test]
55155635
fn test_lists_in_struct() {
55165636
let field_a = Arc::new(Field::new("A", DataType::Utf8, false));

0 commit comments

Comments
 (0)