diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 68d359018557..da5ab4e14e01 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -17,22 +17,27 @@ //! Defines physical expressions that can evaluated at runtime during query execution +use crate::aggregate::groups_accumulator::accumulate::{ + accumulate_array, accumulate_array_elements, NullState, +}; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; -use datafusion_common::{DataFusionError, Result}; use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; -use arrow_array::cast::AsArray; -use arrow_array::types::{Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; -use crate::aggregate::groups_accumulator::accumulate::{accumulate_array, accumulate_array_elements, NullState}; /// ARRAY_AGG aggregate expression #[derive(Debug)] @@ -107,19 +112,37 @@ impl AggregateExpr for ArrayAgg { fn create_groups_accumulator(&self) -> Result> { match self.input_data_type { DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Int32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Int64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt8 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Float32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Float64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } _ => Err(DataFusionError::Internal(format!( "ArrayAggGroupsAccumulator not supported for data type {:?}", self.input_data_type - ))) + ))), } } } @@ -237,7 +260,6 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send + Sync, { - // TODO: // 1. Implement support for null state // 2. Implement support for low level ListArray creation api with offsets and nulls @@ -250,7 +272,7 @@ where values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, - total_num_groups: usize + total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); let values = values[0].as_primitive::(); @@ -280,7 +302,7 @@ where values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, - total_num_groups: usize + total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to merge_batch"); let values = values[0].as_list(); @@ -317,7 +339,6 @@ where } fn state(&mut self, emit_to: EmitTo) -> Result> { - // TODO: do we need null state? // let nulls = self.null_state.build(emit_to); // let nulls = Some(nulls); @@ -329,12 +350,14 @@ where } fn size(&self) -> usize { - self.values.capacity() + - self.values - .iter() - .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) - .sum::() * std::mem::size_of::() + - self.null_state.size() + self.values.capacity() + + self + .values + .iter() + .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) + .sum::() + * std::mem::size_of::() + + self.null_state.size() } } @@ -419,7 +442,13 @@ mod tests { ])]); let expected = ScalarValue::List(Arc::new(list.clone())); - test_op!(a.clone(), DataType::Int32, ArrayAgg, expected, DataType::Int32); + test_op!( + a.clone(), + DataType::Int32, + ArrayAgg, + expected, + DataType::Int32 + ); let expected: ArrayRef = Arc::new(list); test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32) diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 3f2425ca6dc0..01f13d38adc0 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -20,8 +20,8 @@ //! [`GroupsAccumulator`]: crate::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray}; use arrow_array::cast::AsArray; +use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use crate::EmitTo; @@ -445,7 +445,7 @@ pub fn accumulate_array_elements( mut value_fn: F, ) where F: FnMut(usize, ::Native) + Send, - T: ArrowPrimitiveType + Send + T: ArrowPrimitiveType + Send, { assert_eq!(values.len(), group_indices.len()); @@ -453,11 +453,12 @@ pub fn accumulate_array_elements( // no filter, None => { let iter = values.iter(); - group_indices.iter().zip(iter).for_each( - |(&group_index, new_value)| { + group_indices + .iter() + .zip(iter) + .for_each(|(&group_index, new_value)| { value_fn(group_index, new_value.unwrap()) - }, - ) + }) } // a filter Some(filter) => { @@ -482,7 +483,7 @@ pub fn accumulate_array( mut value_fn: F, ) where F: FnMut(usize, &PrimitiveArray) + Send, - T: ArrowPrimitiveType + Send + T: ArrowPrimitiveType + Send, { assert_eq!(values.len(), group_indices.len()); @@ -490,11 +491,12 @@ pub fn accumulate_array( // no filter, None => { let iter = values.iter(); - group_indices.iter().zip(iter).for_each( - |(&group_index, new_value)| { + group_indices + .iter() + .zip(iter) + .for_each(|(&group_index, new_value)| { value_fn(group_index, new_value.unwrap().as_primitive::()) - }, - ) + }) } // a filter Some(filter) => {