Skip to content

Commit

Permalink
Add new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafasrepo committed Jul 22, 2024
1 parent d7d9189 commit ce82306
Showing 1 changed file with 215 additions and 29 deletions.
244 changes: 215 additions & 29 deletions datafusion/physical-expr/src/equivalence/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1532,26 +1532,19 @@ impl Hash for ExprWrapper {
}

/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties`
/// of `lhs` and `rhs` according to the given output `schema` (which need not
/// be the same with those of `lhs` and `rhs` as details such as nullability
/// may be different).
/// of `lhs` and `rhs` according to the schema of the `lhs`.
fn calculate_union_binary(
mut lhs: EquivalenceProperties,
lhs: EquivalenceProperties,
mut rhs: EquivalenceProperties,
schema: SchemaRef,
) -> Result<EquivalenceProperties> {
// TODO: In some cases, we should be able to preserve some equivalence
// classes. Add support for such cases.

// Harmonize the schemas of the two sides with the output schema:
if !lhs.schema.eq(&schema) {
lhs = lhs.with_new_schema(Arc::clone(&schema))?;
}
if !rhs.schema.eq(&schema) {
rhs = rhs.with_new_schema(Arc::clone(&schema))?;
// Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema):
if !rhs.schema.eq(&lhs.schema) {
rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?;
}

let mut eq_properties = EquivalenceProperties::new(schema);
// First, calculate valid constants for the union. A quantity is constant
// after the union if it is constant in both sides.
let constants = lhs
Expand All @@ -1564,8 +1557,8 @@ fn calculate_union_binary(
// the same. However, we do not have the capability to
// check this yet.
ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false)
});
eq_properties = eq_properties.add_constants(constants);
})
.collect::<Vec<_>>();

// Next, calculate valid orderings for the union by searching for prefixes
// in both sides.
Expand All @@ -1590,6 +1583,8 @@ fn calculate_union_binary(
orderings.push(ordering);
}
}
let mut eq_properties = EquivalenceProperties::new(lhs.schema);
eq_properties.constants = constants;
eq_properties.add_new_orderings(orderings);
Ok(eq_properties)
}
Expand All @@ -1604,10 +1599,14 @@ pub fn calculate_union(
) -> Result<EquivalenceProperties> {
// TODO: In some cases, we should be able to preserve some equivalence
// classes. Add support for such cases.
let init = eqps[0].clone();
eqps.into_iter().skip(1).try_fold(init, |acc, eqp| {
calculate_union_binary(acc, eqp, Arc::clone(&schema))
})
let mut init = eqps[0].clone();
// Harmonize the schema of the init with the schema of the union:
if !init.schema.eq(&schema) {
init = init.with_new_schema(schema)?;
}
eqps.into_iter()
.skip(1)
.try_fold(init, calculate_union_binary)
}

#[cfg(test)]
Expand Down Expand Up @@ -2608,23 +2607,202 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_union_equivalence_properties_binary() -> Result<()> {
let schema = create_test_schema()?;
let schema2 = Schema::new(
fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef {
Arc::new(Schema::new(
schema
.fields()
.iter()
.map(|field| {
Field::new(
// Annotate name with1, to change field name.
format!("{}1", field.name()),
format!("{}{}", field.name(), text),
field.data_type().clone(),
field.is_nullable(),
)
})
.collect::<Vec<_>>(),
);
))
}

#[tokio::test]
async fn test_union_equivalence_properties_multi_children() -> Result<()> {
let schema = create_test_schema()?;
let schema2 = append_fields(&schema, "1");
let schema3 = append_fields(&schema, "2");
let test_cases = vec![
// --------- TEST CASE 1 ----------
(
vec![
// Children 1
(
// Orderings
vec![vec!["a", "b", "c"]],
Arc::clone(&schema),
),
// Children 2
(
// Orderings
vec![vec!["a1", "b1", "c1"]],
Arc::clone(&schema2),
),
// Children 3
(
// Orderings
vec![vec!["a2", "b2"]],
Arc::clone(&schema3),
),
],
// Expected
vec![vec!["a", "b"]],
),
// --------- TEST CASE 2 ----------
(
vec![
// Children 1
(
// Orderings
vec![vec!["a", "b", "c"]],
Arc::clone(&schema),
),
// Children 2
(
// Orderings
vec![vec!["a1", "b1", "c1"]],
Arc::clone(&schema2),
),
// Children 3
(
// Orderings
vec![vec!["a2", "b2", "c2"]],
Arc::clone(&schema3),
),
],
// Expected
vec![vec!["a", "b", "c"]],
),
// --------- TEST CASE 3 ----------
(
vec![
// Children 1
(
// Orderings
vec![vec!["a", "b"]],
Arc::clone(&schema),
),
// Children 2
(
// Orderings
vec![vec!["a1", "b1", "c1"]],
Arc::clone(&schema2),
),
// Children 3
(
// Orderings
vec![vec!["a2", "b2", "c2"]],
Arc::clone(&schema3),
),
],
// Expected
vec![vec!["a", "b"]],
),
// --------- TEST CASE 4 ----------
(
vec![
// Children 1
(
// Orderings
vec![vec!["a", "b"]],
Arc::clone(&schema),
),
// Children 2
(
// Orderings
vec![vec!["a1", "b1"]],
Arc::clone(&schema2),
),
// Children 3
(
// Orderings
vec![vec!["b2", "c2"]],
Arc::clone(&schema3),
),
],
// Expected
vec![],
),
// --------- TEST CASE 5 ----------
(
vec![
// Children 1
(
// Orderings
vec![vec!["a", "b"], vec!["c"]],
Arc::clone(&schema),
),
// Children 2
(
// Orderings
vec![vec!["a1", "b1"], vec!["c1"]],
Arc::clone(&schema2),
),
],
// Expected
vec![vec!["a", "b"], vec!["c"]],
),
];
for (children, expected) in test_cases {
let children_eqs = children
.iter()
.map(|(orderings, schema)| {
let orderings = orderings
.iter()
.map(|ordering| {
ordering
.iter()
.map(|name| PhysicalSortExpr {
expr: col(name, schema).unwrap(),
options: SortOptions::default(),
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
EquivalenceProperties::new_with_orderings(
Arc::clone(schema),
&orderings,
)
})
.collect::<Vec<_>>();
let actual = calculate_union(children_eqs, Arc::clone(&schema))?;

let expected_ordering = expected
.into_iter()
.map(|ordering| {
ordering
.into_iter()
.map(|name| PhysicalSortExpr {
expr: col(name, &schema).unwrap(),
options: SortOptions::default(),
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let expected = EquivalenceProperties::new_with_orderings(
Arc::clone(&schema),
&expected_ordering,
);
assert_eq_properties_same(
&actual,
&expected,
format!("expected: {:?}, actual: {:?}", expected, actual),
);
}
Ok(())
}

#[tokio::test]
async fn test_union_equivalence_properties_binary() -> Result<()> {
let schema = create_test_schema()?;
let schema2 = append_fields(&schema, "1");
let col_a = &col("a", &schema)?;
let col_b = &col("b", &schema)?;
let col_c = &col("c", &schema)?;
Expand All @@ -2643,6 +2821,7 @@ mod tests {
],
// First child constants
vec![col_b, col_c],
Arc::clone(&schema),
),
(
// Second child orderings
Expand All @@ -2652,6 +2831,7 @@ mod tests {
],
// Second child constants
vec![col_a, col_c],
Arc::clone(&schema),
),
(
// Union expected orderings
Expand All @@ -2676,6 +2856,7 @@ mod tests {
],
// No constant
vec![],
Arc::clone(&schema),
),
(
// Second child orderings
Expand All @@ -2685,6 +2866,7 @@ mod tests {
],
// No constant
vec![],
Arc::clone(&schema),
),
(
// Union orderings
Expand All @@ -2707,6 +2889,7 @@ mod tests {
],
// No constant
vec![],
Arc::clone(&schema),
),
(
// Second child orderings
Expand All @@ -2716,6 +2899,7 @@ mod tests {
],
// No constant
vec![],
Arc::clone(&schema),
),
(
// Union doesn't have any ordering
Expand All @@ -2736,6 +2920,7 @@ mod tests {
],
// No constant
vec![],
Arc::clone(&schema),
),
(
// Second child orderings
Expand All @@ -2745,6 +2930,7 @@ mod tests {
],
// No constant
vec![],
Arc::clone(&schema2),
),
(
// Union orderings
Expand All @@ -2761,8 +2947,8 @@ mod tests {
for (
test_idx,
(
(first_child_orderings, first_child_constants),
(second_child_orderings, second_child_constants),
(first_child_orderings, first_child_constants, first_schema),
(second_child_orderings, second_child_constants, second_schema),
(union_orderings, union_constants),
),
) in test_cases.iter().enumerate()
Expand All @@ -2775,7 +2961,7 @@ mod tests {
.iter()
.map(|expr| ConstExpr::new(Arc::clone(expr)))
.collect::<Vec<_>>();
let mut lhs = EquivalenceProperties::new(Arc::clone(&schema));
let mut lhs = EquivalenceProperties::new(Arc::clone(first_schema));
lhs = lhs.add_constants(first_constants);
lhs.add_new_orderings(first_orderings);

Expand All @@ -2787,7 +2973,7 @@ mod tests {
.iter()
.map(|expr| ConstExpr::new(Arc::clone(expr)))
.collect::<Vec<_>>();
let mut rhs = EquivalenceProperties::new(Arc::clone(&schema));
let mut rhs = EquivalenceProperties::new(Arc::clone(second_schema));
rhs = rhs.add_constants(second_constants);
rhs.add_new_orderings(second_orderings);

Expand All @@ -2803,7 +2989,7 @@ mod tests {
union_expected_eq = union_expected_eq.add_constants(union_constants);
union_expected_eq.add_new_orderings(union_expected_orderings);

let actual_union_eq = calculate_union_binary(lhs, rhs, Arc::clone(&schema))?;
let actual_union_eq = calculate_union_binary(lhs, rhs)?;
let err_msg = format!(
"Error in test id: {:?}, test case: {:?}",
test_idx, test_cases[test_idx]
Expand Down

0 comments on commit ce82306

Please sign in to comment.