Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] update list & struct coercion to support incrementality #15259

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion datafusion/datasource/src/schema_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use arrow::array::{new_null_array, RecordBatch, RecordBatchOptions};
use arrow::compute::{can_cast_types, cast};
use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::plan_err;
use datafusion_expr::binary::struct_coercion;
use datafusion_expr::binary::list_coercion;
use std::fmt::Debug;
use std::sync::Arc;

Expand Down Expand Up @@ -269,7 +271,11 @@ impl SchemaAdapter for DefaultSchemaAdapter {
if let Some((table_idx, table_field)) =
self.projected_table_schema.fields().find(file_field.name())
{
match can_cast_types(file_field.data_type(), table_field.data_type()) {
match
can_cast_types(file_field.data_type(), table_field.data_type()) ||
struct_coercion(file_field.data_type(), table_field.data_type()).is_some() ||
list_coercion(file_field.data_type(), table_field.data_type()).is_some()
{
true => {
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
Expand Down
94 changes: 76 additions & 18 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,27 @@ pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
return Some(DataType::Utf8);
}

// Check if all data_types are Structs and coalescable
if data_types.iter().all(|t| matches!(t, DataType::Struct(_))) {
let mut max_cardinality_struct: Option<&DataType> = None;
for data_type in data_types.iter() {
if let Some(DataType::Struct(fields)) = max_cardinality_struct {
if let DataType::Struct(new_fields) = data_type {
if fields.len() == new_fields.len() {
max_cardinality_struct = Some(data_type);
} else if new_fields.len() > fields.len() {
max_cardinality_struct = Some(data_type);
}
}
} else {
max_cardinality_struct = Some(data_type);
}
}
if let Some(max_struct) = max_cardinality_struct {
return Some(max_struct.clone());
}
}

// Ignore Nulls, if any data_type category is not the same, return None
let data_types_category: Vec<TypeCategory> = data_types
.iter()
Expand Down Expand Up @@ -976,27 +997,48 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType
}
}

fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
pub fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Struct(lhs_fields), Struct(rhs_fields)) => {
if lhs_fields.len() != rhs_fields.len() {
return None;
}

let coerced_types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter())
.map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type()))
.collect::<Option<Vec<DataType>>>()?;

// preserve the field name and nullability
let orig_fields = std::iter::zip(lhs_fields.iter(), rhs_fields.iter());
use std::collections::{HashMap, HashSet};
// Create maps from field name to field
let lhs_map: HashMap<String, FieldRef> = lhs_fields
.iter()
.map(|f| (f.name().clone(), Arc::clone(f)))
.collect();
let rhs_map: HashMap<String, FieldRef> = rhs_fields
.iter()
.map(|f| (f.name().clone(), Arc::clone(f)))
.collect();

let fields: Vec<FieldRef> = coerced_types
.into_iter()
.zip(orig_fields)
.map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs))
// Compute the union of all field names
let keys: HashSet<String> = lhs_map
.keys()
.chain(rhs_map.keys())
.cloned()
.collect();
Some(Struct(fields.into()))

let mut coerced_fields: Vec<FieldRef> = Vec::new();
for key in keys {
let lhs_field = lhs_map.get(&key);
let rhs_field = rhs_map.get(&key);

// If a field is missing on one side, treat it as Null.
let lhs_dt = lhs_field.map(|f| f.data_type()).unwrap_or(&Null);
let rhs_dt = rhs_field.map(|f| f.data_type()).unwrap_or(&Null);

let coerced_dt = comparison_coercion(lhs_dt, rhs_dt)?;
// Determine nullability:
// - If the field is missing on one side, we consider it nullable.
// - Otherwise, true if either field is nullable.
let nullable = lhs_field.map(|f| f.is_nullable()).unwrap_or(true)
|| rhs_field.map(|f| f.is_nullable()).unwrap_or(true);
coerced_fields.push(Arc::new(Field::new(&key, coerced_dt, nullable)));
}
// Optionally, sort fields (here we sort alphabetically by name)
coerced_fields.sort_by(|a, b| a.name().cmp(b.name()));
Some(Struct(coerced_fields.into()))
}
_ => None,
}
Expand Down Expand Up @@ -1221,7 +1263,7 @@ fn coerce_list_children(lhs_field: &FieldRef, rhs_field: &FieldRef) -> Option<Fi
}

/// Coercion rules for list types.
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
pub fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// Coerce to the left side FixedSizeList type if the list lengths are the same,
Expand Down Expand Up @@ -2240,7 +2282,6 @@ mod tests {
#[test]
fn test_list_coercion() {
let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false)));

let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true)));

let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap();
Expand All @@ -2250,6 +2291,23 @@ mod tests {
); // nullable because the RHS is nullable
}

#[test]
fn test_struct_coercion() {
let lhs_struct = DataType::Struct(vec![
Field::new("a", DataType::Int8, false),
Field::new("b", DataType::Int16, false),
].into());

let rhs_struct = DataType::Struct(vec![
Field::new("a", DataType::Int8, false),
Field::new("b", DataType::Int16, false),
Field::new("c", DataType::Int32, true),
].into());

let coerced_type = type_union_resolution(&[lhs_struct.clone(), rhs_struct.clone()]).unwrap();
assert_eq!(coerced_type, rhs_struct);
}

#[test]
fn test_type_coercion_logical_op() -> Result<()> {
test_coercion_binary_rule!(
Expand Down
Loading