diff --git a/src/expr/src/relation.rs b/src/expr/src/relation.rs index 2c2035a7fd1dc..0540bbb46a75c 100644 --- a/src/expr/src/relation.rs +++ b/src/expr/src/relation.rs @@ -25,6 +25,7 @@ use mz_ore::collections::CollectionExt; use mz_ore::id_gen::IdGen; use mz_ore::metrics::Histogram; use mz_ore::num::NonNeg; +use mz_ore::soft_assert_eq_no_log; use mz_ore::stack::RecursionLimitError; use mz_ore::str::Indent; use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError}; @@ -1688,8 +1689,44 @@ impl MirRelationExpr { } /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the correct type. + /// + /// This calls `.typ()` to determine scalar types. If you already know the scalar types type, + /// then use either + /// [MirRelationExpr::take_safely_with_col_types] or [MirRelationExpr::take_safely_with_rel_type]. pub fn take_safely(&mut self) -> MirRelationExpr { - let typ = self.typ(); + self.take_safely_with_rel_type(self.typ()) + } + + /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar + /// types. Keys and nullability are ignored in the given `RelationType`, and instead we set the + /// best possible key and nullability, since we are making an empty collection. + /// + /// Compared to `take_safely()`, this just saves the cost of the `.typ()` call. + pub fn take_safely_with_col_types(&mut self, typ: Vec) -> MirRelationExpr { + self.take_safely_with_rel_type(RelationType::new(typ)) + } + + /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar + /// types. Keys and nullability are ignored in the given `RelationType`, and instead we set the + /// best possible key and nullability, since we are making an empty collection. + /// + /// Compared to `take_safely()`, this just saves the cost of the `.typ()` call. + pub fn take_safely_with_rel_type(&mut self, mut typ: RelationType) -> MirRelationExpr { + soft_assert_eq_no_log!( + self.typ() + .column_types + .iter() + .map(|ct| ct.scalar_type.clone()) + .collect_vec(), + typ.column_types + .iter() + .map(|ct| ct.scalar_type.clone()) + .collect_vec() + ); + typ.keys = vec![vec![]]; + for ct in typ.column_types.iter_mut() { + ct.nullable = false; + } std::mem::replace( self, MirRelationExpr::Constant { @@ -1698,6 +1735,7 @@ impl MirRelationExpr { }, ) } + /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type. /// /// This should only be used if `self` is about to be dropped or otherwise overwritten. @@ -2233,11 +2271,13 @@ impl MirRelationExpr { value.visit_pre_mut(|e| { if let MirRelationExpr::Get { id: crate::Id::Local(id), + typ, .. } = e { + let typ = typ.clone(); if deadlist.contains(id) { - e.take_safely(); + e.take_safely_with_rel_type(typ); } } }); diff --git a/src/transform/src/equivalence_propagation.rs b/src/transform/src/equivalence_propagation.rs index cbf7c9a5c3bdc..ba3facce1bf07 100644 --- a/src/transform/src/equivalence_propagation.rs +++ b/src/transform/src/equivalence_propagation.rs @@ -124,6 +124,7 @@ impl EquivalencePropagation { let expr_type = derived .value::() .expect("RelationType required"); + assert!(expr_type.is_some()); let expr_equivalences = derived .value::() .expect("Equivalences required"); @@ -132,7 +133,7 @@ impl EquivalencePropagation { let expr_equivalences = if let Some(e) = expr_equivalences { e } else { - expr.take_safely(); + expr.take_safely_with_col_types(expr_type.clone().unwrap()); return; }; @@ -147,7 +148,7 @@ impl EquivalencePropagation { outer_equivalences.minimize(expr_type.as_ref().map(|x| &x[..])); if outer_equivalences.unsatisfiable() { - expr.take_safely(); + expr.take_safely_with_col_types(expr_type.clone().unwrap()); return; } diff --git a/src/transform/src/fold_constants.rs b/src/transform/src/fold_constants.rs index 68016f7720b7a..58a60850fadc6 100644 --- a/src/transform/src/fold_constants.rs +++ b/src/transform/src/fold_constants.rs @@ -250,7 +250,7 @@ impl FoldConstants { .iter() .any(|p| p.is_literal_false() || p.is_literal_null()) { - relation.take_safely(); + relation.take_safely_with_rel_type(relation_type.clone()); } else if let Some((rows, ..)) = (**input).as_const() { // Evaluate errors last, to reduce risk of spurious errors. predicates.sort_by_key(|p| p.is_literal_err()); @@ -291,7 +291,7 @@ impl FoldConstants { .. } => { if inputs.iter().any(|e| e.is_empty()) { - relation.take_safely(); + relation.take_safely_with_rel_type(relation_type.clone()); } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) { *relation = MirRelationExpr::Constant { rows: Err(e.clone()), diff --git a/src/transform/src/predicate_pushdown.rs b/src/transform/src/predicate_pushdown.rs index 95c7dfa188f3b..7def77b8d21df 100644 --- a/src/transform/src/predicate_pushdown.rs +++ b/src/transform/src/predicate_pushdown.rs @@ -595,7 +595,9 @@ impl PredicatePushdown { .count() > 1 { - relation.take_safely(); + relation.take_safely_with_rel_type( + relation.typ_with_input_types(&input_types), + ); return Ok(()); }