
Commit f8d8603

Authored by Randall Spears, with Spears Randall and alamb
Change ScalarValue::{List, LargeList, FixedSizedList} to take specific types rather than ArrayRef (#8562)
* Change ScalarValue::List type signature (also ScalarValue::LargeList and ScalarValue::FixedSizeList)
* Formatting/cleanup
* Remove duplicate match statements
* Add back scalar eq_array test for List
* Formatting
* Reduce code duplication
* Fix merge conflict
* Fix post-merge compile errors
* Remove redundant partial_cmp implementation
* improve
* Cargo fmt fix
* Reduce duplication in formatter
* Reduce more duplication
* Fix test error

---------

Co-authored-by: Spears Randall <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
1 parent be8a953 commit f8d8603

File tree: 7 files changed, +243 -195 lines


datafusion/common/src/scalar.rs (+173 -132)
Large diffs are not rendered by default.
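Since the scalar.rs diff is collapsed here, below is a minimal sketch of how the reworked variant is presumably constructed after this commit. It assumes `ScalarValue::List` now holds an `Arc<ListArray>` (with `LargeList` and `FixedSizeList` holding their corresponding concrete array types), which is what the updated call sites further down suggest; the exact field types are inferred from those call sites rather than from the collapsed diff.

```rust
use std::sync::Arc;

use arrow_array::{types::Int32Type, ListArray};
use datafusion_common::ScalarValue;

fn main() {
    // Build a concrete ListArray holding a single list: [[1, 2, 3]]
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
        Some(1),
        Some(2),
        Some(3),
    ])]);

    // After this commit the variant takes the concrete array type directly,
    // so no ArrayRef (Arc<dyn Array>) indirection is needed here.
    let scalar = ScalarValue::List(Arc::new(list));
    println!("{scalar}");
}
```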

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs (+2 -2)

@@ -27,7 +27,7 @@ use crate::simplify_expressions::regex::simplify_regex_expr;
 use crate::simplify_expressions::SimplifyInfo;
 
 use arrow::{
-    array::new_null_array,
+    array::{new_null_array, AsArray},
     datatypes::{DataType, Field, Schema},
     record_batch::RecordBatch,
 };
@@ -396,7 +396,7 @@ impl<'a> ConstEvaluator<'a> {
                a.len()
            )
        } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() {
-           Ok(ScalarValue::List(a))
+           Ok(ScalarValue::List(a.as_list().to_owned().into()))
        } else {
            // Non-ListArray
            ScalarValue::try_from_array(&a, 0)
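For reference, here is a self-contained sketch of the conversion pattern the new call site uses: downcasting a type-erased `ArrayRef` to a concrete `&ListArray` with `AsArray::as_list`, then cloning it into the `Arc<ListArray>` the variant now expects. The array contents here are made up for illustration.

```rust
use std::sync::Arc;

use arrow_array::{cast::AsArray, types::Int32Type, Array, ArrayRef, ListArray};

fn main() {
    // A type-erased array, like the constant-folded result in ConstEvaluator above.
    let a: ArrayRef = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2)]),
    ]));

    // as_list::<i32>() downcasts to &ListArray, to_owned() clones the
    // (cheap, buffer-backed) array, and into() wraps it in a fresh Arc.
    let concrete: Arc<ListArray> = a.as_list::<i32>().to_owned().into();
    assert_eq!(concrete.len(), 1);
}
```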

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs (+1 -5)

@@ -186,7 +186,6 @@ mod tests {
     use arrow::array::{ArrayRef, Int32Array};
     use arrow::datatypes::{DataType, Schema};
     use arrow::record_batch::RecordBatch;
-    use arrow_array::cast::as_list_array;
     use arrow_array::types::Int32Type;
     use arrow_array::{Array, ListArray};
     use arrow_buffer::OffsetBuffer;
@@ -196,10 +195,7 @@ mod tests {
     // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray.
     fn sort_list_inner(arr: ScalarValue) -> ScalarValue {
         let arr = match arr {
-            ScalarValue::List(arr) => {
-                let list_arr = as_list_array(&arr);
-                list_arr.value(0)
-            }
+            ScalarValue::List(arr) => arr.value(0),
             _ => {
                 panic!("Expected ScalarValue::List, got {:?}", arr)
             }
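The simplification above relies on the match arm binding the concrete list array directly. A small illustrative sketch of the same pattern (the values are arbitrary):

```rust
use std::sync::Arc;

use arrow_array::{types::Int32Type, Array, ArrayRef, ListArray};
use datafusion_common::ScalarValue;

fn main() {
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
        Some(3),
        Some(1),
    ])]);
    let scalar = ScalarValue::List(Arc::new(list));

    // The arm now binds an Arc<ListArray>, so value(0) is available directly;
    // the old as_list_array() downcast is no longer needed.
    let inner: ArrayRef = match scalar {
        ScalarValue::List(arr) => arr.value(0),
        other => panic!("Expected ScalarValue::List, got {other:?}"),
    };
    assert_eq!(inner.len(), 2);
}
```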

datafusion/physical-expr/src/aggregate/count_distinct.rs (+2 -2)

@@ -292,7 +292,7 @@
         let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
             self.values.iter().cloned(),
         )) as ArrayRef;
-        let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+        let list = Arc::new(array_into_list_array(arr));
         Ok(vec![ScalarValue::List(list)])
     }
 
@@ -378,7 +378,7 @@
         let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
             self.values.iter().map(|v| v.0),
         )) as ArrayRef;
-        let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+        let list = Arc::new(array_into_list_array(arr));
         Ok(vec![ScalarValue::List(list)])
     }
 
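The change above simply drops the `as ArrayRef` upcast, since the variant now accepts the `Arc<ListArray>` produced by the helper directly. A hedged sketch of that construction, assuming the helper lives at `datafusion_common::utils::array_into_list_array` with an `ArrayRef -> ListArray` signature (the diff only shows it being called, not imported):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array};
// Path and signature assumed; not visible in this diff.
use datafusion_common::utils::array_into_list_array;
use datafusion_common::ScalarValue;

fn main() {
    let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

    // array_into_list_array wraps the values into a single-element ListArray;
    // Arc::new(..) then already has the type the variant expects, so the old
    // `as ArrayRef` cast is no longer required.
    let list = Arc::new(array_into_list_array(arr));
    let state = vec![ScalarValue::List(list)];
    assert_eq!(state.len(), 1);
}
```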

datafusion/physical-expr/src/aggregate/tdigest.rs (+2 -4)

@@ -28,7 +28,6 @@
 //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
 
 use arrow::datatypes::DataType;
-use arrow_array::cast::as_list_array;
 use arrow_array::types::Float64Type;
 use datafusion_common::cast::as_primitive_array;
 use datafusion_common::Result;
@@ -606,11 +605,10 @@ impl TDigest {
 
         let centroids: Vec<_> = match &state[5] {
             ScalarValue::List(arr) => {
-                let list_array = as_list_array(arr);
-                let arr = list_array.values();
+                let array = arr.values();
 
                 let f64arr =
-                    as_primitive_array::<Float64Type>(arr).expect("expected f64 array");
+                    as_primitive_array::<Float64Type>(array).expect("expected f64 array");
                 f64arr
                     .values()
                     .chunks(2)
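The same idea as a standalone sketch: with the concrete `ListArray` in hand, `values()` yields the flattened child array from which the centroid (mean, weight) pairs are read. The numbers are made up.

```rust
use std::sync::Arc;

use arrow_array::{types::Float64Type, ListArray};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::ScalarValue;

fn main() {
    // A list scalar holding flattened (mean, weight) pairs: [[10.0, 1.0, 20.0, 2.0]]
    let list = ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(vec![
        Some(10.0),
        Some(1.0),
        Some(20.0),
        Some(2.0),
    ])]);
    let scalar = ScalarValue::List(Arc::new(list));

    if let ScalarValue::List(arr) = &scalar {
        // values() on the concrete ListArray returns the flattened child array,
        // so no as_list_array() downcast is needed before reading it.
        let f64arr = as_primitive_array::<Float64Type>(arr.values())
            .expect("expected f64 array");
        let pairs: Vec<(f64, f64)> =
            f64arr.values().chunks(2).map(|c| (c[0], c[1])).collect();
        assert_eq!(pairs, vec![(10.0, 1.0), (20.0, 2.0)]);
    }
}
```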

datafusion/proto/src/logical_plan/from_proto.rs (+10 -3)

@@ -27,6 +27,7 @@ use crate::protobuf::{
     OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
 };
 use arrow::{
+    array::AsArray,
     buffer::Buffer,
     datatypes::{
         i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit,
@@ -722,9 +723,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
                     .map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
                 let arr = record_batch.column(0);
                 match value {
-                    Value::ListValue(_) => Self::List(arr.to_owned()),
-                    Value::LargeListValue(_) => Self::LargeList(arr.to_owned()),
-                    Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()),
+                    Value::ListValue(_) => {
+                        Self::List(arr.as_list::<i32>().to_owned().into())
+                    }
+                    Value::LargeListValue(_) => {
+                        Self::LargeList(arr.as_list::<i64>().to_owned().into())
+                    }
+                    Value::FixedSizeListValue(_) => {
+                        Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into())
+                    }
                     _ => unreachable!(),
                 }
             }
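Each protobuf value kind is now mapped to the matching concrete array type before building the scalar. A sketch of the FixedSizeList case, built here from scratch rather than from a decoded IPC batch:

```rust
use std::sync::Arc;

use arrow_array::{cast::AsArray, ArrayRef, FixedSizeListArray, Int32Array};
use arrow_schema::{DataType, Field};
use datafusion_common::ScalarValue;

fn main() {
    // A type-erased column, standing in for record_batch.column(0) above.
    let field = Arc::new(Field::new("item", DataType::Int32, true));
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let arr: ArrayRef = Arc::new(FixedSizeListArray::new(field, 3, values, None));

    // as_fixed_size_list() (like as_list::<i32>() and as_list::<i64>() for the
    // other variants) recovers the concrete array the variant now stores.
    let scalar = ScalarValue::FixedSizeList(arr.as_fixed_size_list().to_owned().into());
    assert!(matches!(scalar, ScalarValue::FixedSizeList(_)));
}
```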

datafusion/proto/src/logical_plan/to_proto.rs (+53 -47)

@@ -32,6 +32,7 @@ use crate::protobuf::{
     OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
 };
 use arrow::{
+    array::ArrayRef,
     datatypes::{
         DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef,
         TimeUnit, UnionMode,
@@ -1159,54 +1160,15 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
             }
             // ScalarValue::List and ScalarValue::FixedSizeList are serialized using
             // Arrow IPC messages as a single column RecordBatch
-            ScalarValue::List(arr)
-            | ScalarValue::LargeList(arr)
-            | ScalarValue::FixedSizeList(arr) => {
+            ScalarValue::List(arr) => {
+                encode_scalar_list_value(arr.to_owned() as ArrayRef, val)
+            }
+            ScalarValue::LargeList(arr) => {
                 // Wrap in a "field_name" column
-                let batch = RecordBatch::try_from_iter(vec![(
-                    "field_name",
-                    arr.to_owned(),
-                )])
-                .map_err(|e| {
-                    Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}"))
-                })?;
-
-                let gen = IpcDataGenerator {};
-                let mut dict_tracker = DictionaryTracker::new(false);
-                let (_, encoded_message) = gen
-                    .encoded_batch(&batch, &mut dict_tracker, &Default::default())
-                    .map_err(|e| {
-                        Error::General(format!(
-                            "Error encoding ScalarValue::List as IPC: {e}"
-                        ))
-                    })?;
-
-                let schema: protobuf::Schema = batch.schema().try_into()?;
-
-                let scalar_list_value = protobuf::ScalarListValue {
-                    ipc_message: encoded_message.ipc_message,
-                    arrow_data: encoded_message.arrow_data,
-                    schema: Some(schema),
-                };
-
-                match val {
-                    ScalarValue::List(_) => Ok(protobuf::ScalarValue {
-                        value: Some(protobuf::scalar_value::Value::ListValue(
-                            scalar_list_value,
-                        )),
-                    }),
-                    ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue {
-                        value: Some(protobuf::scalar_value::Value::LargeListValue(
-                            scalar_list_value,
-                        )),
-                    }),
-                    ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
-                        value: Some(protobuf::scalar_value::Value::FixedSizeListValue(
-                            scalar_list_value,
-                        )),
-                    }),
-                    _ => unreachable!(),
-                }
+                encode_scalar_list_value(arr.to_owned() as ArrayRef, val)
+            }
+            ScalarValue::FixedSizeList(arr) => {
+                encode_scalar_list_value(arr.to_owned() as ArrayRef, val)
             }
             ScalarValue::Date32(val) => {
                 create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s))
@@ -1723,3 +1685,47 @@ fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(
 
     Ok(protobuf::ScalarValue { value: Some(value) })
 }
+
+fn encode_scalar_list_value(
+    arr: ArrayRef,
+    val: &ScalarValue,
+) -> Result<protobuf::ScalarValue, Error> {
+    let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| {
+        Error::General(format!(
+            "Error creating temporary batch while encoding ScalarValue::List: {e}"
+        ))
+    })?;
+
+    let gen = IpcDataGenerator {};
+    let mut dict_tracker = DictionaryTracker::new(false);
+    let (_, encoded_message) = gen
+        .encoded_batch(&batch, &mut dict_tracker, &Default::default())
+        .map_err(|e| {
+            Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
+        })?;
+
+    let schema: protobuf::Schema = batch.schema().try_into()?;
+
+    let scalar_list_value = protobuf::ScalarListValue {
+        ipc_message: encoded_message.ipc_message,
+        arrow_data: encoded_message.arrow_data,
+        schema: Some(schema),
+    };
+
+    match val {
+        ScalarValue::List(_) => Ok(protobuf::ScalarValue {
+            value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)),
+        }),
+        ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue {
+            value: Some(protobuf::scalar_value::Value::LargeListValue(
+                scalar_list_value,
+            )),
+        }),
+        ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
+            value: Some(protobuf::scalar_value::Value::FixedSizeListValue(
+                scalar_list_value,
+            )),
+        }),
+        _ => unreachable!(),
+    }
+}
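On the encoding side, the concrete `Arc<ListArray>` has to be turned back into an `ArrayRef` before it can become the single column of the temporary IPC batch, which is what the `arr.to_owned() as ArrayRef` calls above do. A minimal sketch of that step in isolation:

```rust
use std::sync::Arc;

use arrow_array::{types::Int32Type, ArrayRef, ListArray, RecordBatch};

fn main() {
    let arr: Arc<ListArray> = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(
        vec![Some(vec![Some(1), Some(2)])],
    ));

    // Coerce the concrete Arc<ListArray> back to ArrayRef (Arc<dyn Array>) and
    // wrap it in a single "field_name" column, mirroring encode_scalar_list_value.
    let batch = RecordBatch::try_from_iter(vec![("field_name", arr as ArrayRef)])
        .expect("single-column batch");
    assert_eq!(batch.num_columns(), 1);
}
```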
