Skip to content

Commit

Permalink
Fix grouping sets behavior when data contains nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
eejbyfeldt committed Sep 22, 2024
1 parent 3bd41bc commit 9c840b0
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 147 deletions.
16 changes: 3 additions & 13 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,6 @@ impl DefaultPhysicalPlanner {
physical_input_schema.clone(),
)?);

// update group column indices based on partial aggregate plan evaluation
let final_group: Vec<Arc<dyn PhysicalExpr>> =
initial_aggr.output_group_expr();

let can_repartition = !groups.is_empty()
&& session_state.config().target_partitions() > 1
&& session_state.config().repartition_aggregations();
Expand All @@ -731,13 +727,7 @@ impl DefaultPhysicalPlanner {
AggregateMode::Final
};

let final_grouping_set = PhysicalGroupBy::new_single(
final_group
.iter()
.enumerate()
.map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
.collect(),
);
let final_grouping_set = initial_aggr.group_expr().as_final();

Arc::new(AggregateExec::try_new(
next_partition_mode,
Expand Down Expand Up @@ -2061,7 +2051,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]], num_output_exprs: 4 })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2088,7 +2078,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]], num_output_exprs: 4 })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {

// Compare the output expressions of the partial operator with the input expressions of the final operator.
physical_exprs_equal(
&input_group_by.output_exprs(),
&input_group_by.output_exprs(&AggregateMode::Partial),
&final_group_by.input_exprs(),
) && input_group_by.groups() == final_group_by.groups()
&& input_group_by.null_expr().len() == final_group_by.null_expr().len()
Expand Down
Loading

0 comments on commit 9c840b0

Please sign in to comment.