Skip to content

Commit

Permalink
Respect ignore_nulls flag in DistinctArrayAgg
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 committed Aug 21, 2024
1 parent 65b2fc9 commit 39835b4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
25 changes: 17 additions & 8 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub struct DistinctArrayAgg {
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
/// If NULLs should be ignored when aggregating
ignore_nulls: bool,
}

impl DistinctArrayAgg {
Expand All @@ -53,13 +55,15 @@ impl DistinctArrayAgg {
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
ignore_nulls: bool,
) -> Self {
let name = name.into();
Self {
name,
input_data_type,
expr,
nullable,
nullable: nullable && !ignore_nulls,
ignore_nulls: nullable && ignore_nulls,
}
}
}
Expand All @@ -74,21 +78,22 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
Field::new_list_field(self.input_data_type.clone(), true),
self.nullable,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctArrayAggAccumulator::try_new(
&self.input_data_type,
self.ignore_nulls,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), true),
Field::new_list_field(self.input_data_type.clone(), true),
self.nullable,
)])
}
Expand Down Expand Up @@ -119,13 +124,15 @@ impl PartialEq<dyn Any> for DistinctArrayAgg {
struct DistinctArrayAggAccumulator {
values: HashSet<ScalarValue>,
datatype: DataType,
ignore_nulls: bool,
}

impl DistinctArrayAggAccumulator {
pub fn try_new(datatype: &DataType) -> Result<Self> {
pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
Ok(Self {
values: HashSet::new(),
datatype: datatype.clone(),
ignore_nulls,
})
}
}
Expand All @@ -137,12 +144,12 @@ impl Accumulator for DistinctArrayAggAccumulator {

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
assert_eq!(values.len(), 1, "batch input should only include 1 column!");

let array = &values[0];

for i in 0..array.len() {
let scalar = ScalarValue::try_from_array(&array, i)?;
self.values.insert(scalar);
if !(self.ignore_nulls && scalar.is_null()) {
self.values.insert(scalar);
}
}

Ok(())
Expand Down Expand Up @@ -239,13 +246,14 @@ mod tests {
) -> Result<()> {
let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?;

let agg = Arc::new(DistinctArrayAgg::new(
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
false,
));

let actual = aggregate(&batch, agg)?;
compare_list_contents(expected, actual)
}
Expand All @@ -262,6 +270,7 @@ mod tests {
"bla".to_string(),
datatype,
true,
false,
));

let mut accum1 = agg.create_accumulator()?;
Expand Down
3 changes: 2 additions & 1 deletion datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn create_aggregate_expr(
ordering_req: &[PhysicalSortExpr],
input_schema: &Schema,
name: impl Into<String>,
_ignore_nulls: bool,
ignore_nulls: bool,
) -> Result<Arc<dyn AggregateExpr>> {
let name = name.into();
// get the result data type for this aggregate function
Expand Down Expand Up @@ -140,6 +140,7 @@ pub fn create_aggregate_expr(
name,
data_type,
is_expr_nullable,
ignore_nulls,
))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Expand Down
8 changes: 5 additions & 3 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES
('b', [1,0]),
('b', [1,0]),
('b', [1,0]),
('b', [0,1])
('b', [0,1]),
(NULL, [0,1]),
('b', NULL)
;

# Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort,
# so they are covered in `datafusion/physical-expr/src/aggregate/array_agg_distinct.rs`
query ??
select array_sort(c1), array_sort(c2) from (
select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table
select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table
);
----
[b, w] [[0, 1], [1, 0]]
[, b, w] [[0, 1], [1, 0]]

statement ok
drop table array_agg_distinct_list_table;
Expand Down

0 comments on commit 39835b4

Please sign in to comment.