Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
lkt committed Apr 4, 2024
1 parent b2225c4 commit b637f84
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 36 deletions.
79 changes: 54 additions & 25 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,27 @@

//! Defines physical expressions that can be evaluated at runtime during query execution

use crate::aggregate::groups_accumulator::accumulate::{
accumulate_array, accumulate_array_elements, NullState,
};
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use arrow_array::cast::AsArray;
use arrow_array::types::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray};
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type};
use crate::aggregate::groups_accumulator::accumulate::{accumulate_array, accumulate_array_elements, NullState};

/// ARRAY_AGG aggregate expression
#[derive(Debug)]
Expand Down Expand Up @@ -107,19 +112,37 @@ impl AggregateExpr for ArrayAgg {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
match self.input_data_type {
DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int8Type>::new())),
DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int16Type>::new())),
DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int32Type>::new())),
DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int64Type>::new())),
DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt8Type>::new())),
DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt16Type>::new())),
DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt32Type>::new())),
DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt64Type>::new())),
DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::<Float32Type>::new())),
DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::<Float64Type>::new())),
DataType::Int16 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Int16Type>::new()))
}
DataType::Int32 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Int32Type>::new()))
}
DataType::Int64 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Int64Type>::new()))
}
DataType::UInt8 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt8Type>::new()))
}
DataType::UInt16 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt16Type>::new()))
}
DataType::UInt32 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt32Type>::new()))
}
DataType::UInt64 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt64Type>::new()))
}
DataType::Float32 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Float32Type>::new()))
}
DataType::Float64 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Float64Type>::new()))
}
_ => Err(DataFusionError::Internal(format!(
"ArrayAggGroupsAccumulator not supported for data type {:?}",
self.input_data_type
)))
))),
}
}
}
Expand Down Expand Up @@ -237,7 +260,6 @@ impl<T> GroupsAccumulator for ArrayAggGroupsAccumulator<T>
where
T: ArrowPrimitiveType + Send + Sync,
{

// TODO:
// 1. Implement support for null state
// 2. Implement support for low level ListArray creation api with offsets and nulls
Expand All @@ -250,7 +272,7 @@ where
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();
Expand Down Expand Up @@ -280,7 +302,7 @@ where
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to merge_batch");
let values = values[0].as_list();
Expand Down Expand Up @@ -317,7 +339,6 @@ where
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {

// TODO: do we need null state?
// let nulls = self.null_state.build(emit_to);
// let nulls = Some(nulls);
Expand All @@ -329,12 +350,14 @@ where
}

fn size(&self) -> usize {
self.values.capacity() +
self.values
.iter()
.map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity())
.sum::<usize>() * std::mem::size_of::<T>() +
self.null_state.size()
self.values.capacity()
+ self
.values
.iter()
.map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity())
.sum::<usize>()
* std::mem::size_of::<T>()
+ self.null_state.size()
}
}

Expand Down Expand Up @@ -419,7 +442,13 @@ mod tests {
])]);
let expected = ScalarValue::List(Arc::new(list.clone()));

test_op!(a.clone(), DataType::Int32, ArrayAgg, expected, DataType::Int32);
test_op!(
a.clone(),
DataType::Int32,
ArrayAgg,
expected,
DataType::Int32
);

let expected: ArrayRef = Arc::new(list);
test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
//! [`GroupsAccumulator`]: crate::GroupsAccumulator

use arrow::datatypes::ArrowPrimitiveType;
use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray};
use arrow_array::cast::AsArray;
use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};

use crate::EmitTo;
Expand Down Expand Up @@ -445,19 +445,20 @@ pub fn accumulate_array_elements<F, T>(
mut value_fn: F,
) where
F: FnMut(usize, <T as ArrowPrimitiveType>::Native) + Send,
T: ArrowPrimitiveType + Send
T: ArrowPrimitiveType + Send,
{
assert_eq!(values.len(), group_indices.len());

match opt_filter {
// no filter,
None => {
let iter = values.iter();
group_indices.iter().zip(iter).for_each(
|(&group_index, new_value)| {
group_indices
.iter()
.zip(iter)
.for_each(|(&group_index, new_value)| {
value_fn(group_index, new_value.unwrap())
},
)
})
}
// a filter
Some(filter) => {
Expand All @@ -482,19 +483,20 @@ pub fn accumulate_array<F, T>(
mut value_fn: F,
) where
F: FnMut(usize, &PrimitiveArray<T>) + Send,
T: ArrowPrimitiveType + Send
T: ArrowPrimitiveType + Send,
{
assert_eq!(values.len(), group_indices.len());

match opt_filter {
// no filter,
None => {
let iter = values.iter();
group_indices.iter().zip(iter).for_each(
|(&group_index, new_value)| {
group_indices
.iter()
.zip(iter)
.for_each(|(&group_index, new_value)| {
value_fn(group_index, new_value.unwrap().as_primitive::<T>())
},
)
})
}
// a filter
Some(filter) => {
Expand Down

0 comments on commit b637f84

Please sign in to comment.