diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 41ba89bb61d7..9497e8628dfb 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -1532,26 +1532,19 @@ impl Hash for ExprWrapper { } /// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` -/// of `lhs` and `rhs` according to the given output `schema` (which need not -/// be the same with those of `lhs` and `rhs` as details such as nullability -/// may be different). +/// of `lhs` and `rhs` according to the schema of the `lhs`. fn calculate_union_binary( - mut lhs: EquivalenceProperties, + lhs: EquivalenceProperties, mut rhs: EquivalenceProperties, - schema: SchemaRef, ) -> Result { // TODO: In some cases, we should be able to preserve some equivalence // classes. Add support for such cases. - // Harmonize the schemas of the two sides with the output schema: - if !lhs.schema.eq(&schema) { - lhs = lhs.with_new_schema(Arc::clone(&schema))?; - } - if !rhs.schema.eq(&schema) { - rhs = rhs.with_new_schema(Arc::clone(&schema))?; + // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema): + if !rhs.schema.eq(&lhs.schema) { + rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?; } - let mut eq_properties = EquivalenceProperties::new(schema); // First, calculate valid constants for the union. A quantity is constant // after the union if it is constant in both sides. let constants = lhs @@ -1564,8 +1557,8 @@ fn calculate_union_binary( // the same. However, we do not have the capability to // check this yet. ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) - }); - eq_properties = eq_properties.add_constants(constants); + }) + .collect::>(); // Next, calculate valid orderings for the union by searching for prefixes // in both sides. @@ -1590,6 +1583,8 @@ fn calculate_union_binary( orderings.push(ordering); } } + let mut eq_properties = EquivalenceProperties::new(lhs.schema); + eq_properties.constants = constants; eq_properties.add_new_orderings(orderings); Ok(eq_properties) } @@ -1604,10 +1599,14 @@ pub fn calculate_union( ) -> Result { // TODO: In some cases, we should be able to preserve some equivalence // classes. Add support for such cases. - let init = eqps[0].clone(); - eqps.into_iter().skip(1).try_fold(init, |acc, eqp| { - calculate_union_binary(acc, eqp, Arc::clone(&schema)) - }) + let mut init = eqps[0].clone(); + // Harmonize the schema of the init with the schema of the union: + if !init.schema.eq(&schema) { + init = init.with_new_schema(schema)?; + } + eqps.into_iter() + .skip(1) + .try_fold(init, calculate_union_binary) } #[cfg(test)] @@ -2608,23 +2607,202 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_union_equivalence_properties_binary() -> Result<()> { - let schema = create_test_schema()?; - let schema2 = Schema::new( + fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef { + Arc::new(Schema::new( schema .fields() .iter() .map(|field| { Field::new( // Annotate name with1, to change field name. - format!("{}1", field.name()), + format!("{}{}", field.name(), text), field.data_type().clone(), field.is_nullable(), ) }) .collect::>(), - ); + )) + } + + #[tokio::test] + async fn test_union_equivalence_properties_multi_children() -> Result<()> { + let schema = create_test_schema()?; + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + let test_cases = vec![ + // --------- TEST CASE 1 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b", "c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b"]], + ), + // --------- TEST CASE 2 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b", "c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b", "c"]], + ), + // --------- TEST CASE 3 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b"]], + ), + // --------- TEST CASE 4 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![], + ), + // --------- TEST CASE 5 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"], vec!["c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1"], vec!["c1"]], + Arc::clone(&schema2), + ), + ], + // Expected + vec![vec!["a", "b"], vec!["c"]], + ), + ]; + for (children, expected) in test_cases { + let children_eqs = children + .iter() + .map(|(orderings, schema)| { + let orderings = orderings + .iter() + .map(|ordering| { + ordering + .iter() + .map(|name| PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options: SortOptions::default(), + }) + .collect::>() + }) + .collect::>(); + EquivalenceProperties::new_with_orderings( + Arc::clone(schema), + &orderings, + ) + }) + .collect::>(); + let actual = calculate_union(children_eqs, Arc::clone(&schema))?; + + let expected_ordering = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|name| PhysicalSortExpr { + expr: col(name, &schema).unwrap(), + options: SortOptions::default(), + }) + .collect::>() + }) + .collect::>(); + let expected = EquivalenceProperties::new_with_orderings( + Arc::clone(&schema), + &expected_ordering, + ); + assert_eq_properties_same( + &actual, + &expected, + format!("expected: {:?}, actual: {:?}", expected, actual), + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_union_equivalence_properties_binary() -> Result<()> { + let schema = create_test_schema()?; + let schema2 = append_fields(&schema, "1"); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; @@ -2643,6 +2821,7 @@ mod tests { ], // First child constants vec![col_b, col_c], + Arc::clone(&schema), ), ( // Second child orderings @@ -2652,6 +2831,7 @@ mod tests { ], // Second child constants vec![col_a, col_c], + Arc::clone(&schema), ), ( // Union expected orderings @@ -2676,6 +2856,7 @@ mod tests { ], // No constant vec![], + Arc::clone(&schema), ), ( // Second child orderings @@ -2685,6 +2866,7 @@ mod tests { ], // No constant vec![], + Arc::clone(&schema), ), ( // Union orderings @@ -2707,6 +2889,7 @@ mod tests { ], // No constant vec![], + Arc::clone(&schema), ), ( // Second child orderings @@ -2716,6 +2899,7 @@ mod tests { ], // No constant vec![], + Arc::clone(&schema), ), ( // Union doesn't have any ordering @@ -2736,6 +2920,7 @@ mod tests { ], // No constant vec![], + Arc::clone(&schema), ), ( // Second child orderings @@ -2745,6 +2930,7 @@ mod tests { ], // No constant vec![], + Arc::clone(&schema2), ), ( // Union orderings @@ -2761,8 +2947,8 @@ mod tests { for ( test_idx, ( - (first_child_orderings, first_child_constants), - (second_child_orderings, second_child_constants), + (first_child_orderings, first_child_constants, first_schema), + (second_child_orderings, second_child_constants, second_schema), (union_orderings, union_constants), ), ) in test_cases.iter().enumerate() @@ -2775,7 +2961,7 @@ mod tests { .iter() .map(|expr| ConstExpr::new(Arc::clone(expr))) .collect::>(); - let mut lhs = EquivalenceProperties::new(Arc::clone(&schema)); + let mut lhs = EquivalenceProperties::new(Arc::clone(first_schema)); lhs = lhs.add_constants(first_constants); lhs.add_new_orderings(first_orderings); @@ -2787,7 +2973,7 @@ mod tests { .iter() .map(|expr| ConstExpr::new(Arc::clone(expr))) .collect::>(); - let mut rhs = EquivalenceProperties::new(Arc::clone(&schema)); + let mut rhs = EquivalenceProperties::new(Arc::clone(second_schema)); rhs = rhs.add_constants(second_constants); rhs.add_new_orderings(second_orderings); @@ -2803,7 +2989,7 @@ mod tests { union_expected_eq = union_expected_eq.add_constants(union_constants); union_expected_eq.add_new_orderings(union_expected_orderings); - let actual_union_eq = calculate_union_binary(lhs, rhs, Arc::clone(&schema))?; + let actual_union_eq = calculate_union_binary(lhs, rhs)?; let err_msg = format!( "Error in test id: {:?}, test case: {:?}", test_idx, test_cases[test_idx]