Skip to content

Commit

Permalink
String support in array_agg
Browse files Browse the repository at this point in the history
  • Loading branch information
lkt committed Apr 10, 2024
1 parent 5e25fa7 commit 3885af2
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 45 deletions.
195 changes: 159 additions & 36 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use arrow_array::builder::{ListBuilder, PrimitiveBuilder};
use arrow_array::builder::{ListBuilder, PrimitiveBuilder, StringBuilder};
use arrow_array::cast::AsArray;
use arrow_array::types::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, PrimitiveArray};
use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, StringArray};
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -111,33 +111,16 @@ impl AggregateExpr for ArrayAgg {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
match self.input_data_type {
DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int8Type>::new())),
DataType::Int16 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Int16Type>::new()))
}
DataType::Int32 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Int32Type>::new()))
}
DataType::Int64 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Int64Type>::new()))
}
DataType::UInt8 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt8Type>::new()))
}
DataType::UInt16 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt16Type>::new()))
}
DataType::UInt32 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt32Type>::new()))
}
DataType::UInt64 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<UInt64Type>::new()))
}
DataType::Float32 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Float32Type>::new()))
}
DataType::Float64 => {
Ok(Box::new(ArrayAggGroupsAccumulator::<Float64Type>::new()))
}
DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int16Type>::new())),
DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int32Type>::new())),
DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::<Int64Type>::new())),
DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt8Type>::new())),
DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt16Type>::new())),
DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt32Type>::new())),
DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::<UInt64Type>::new())),
DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::<Float32Type>::new())),
DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::<Float64Type>::new())),
DataType::Utf8 => Ok(Box::new(StringArrayAggGroupsAccumulator::new())),
_ => Err(DataFusionError::Internal(format!(
"ArrayAggGroupsAccumulator not supported for data type {:?}",
self.input_data_type
Expand Down Expand Up @@ -335,7 +318,8 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value: &PrimitiveArray<T>| {
|group_index, new_value: ArrayRef| {
let new_value = new_value.as_primitive::<T>();
self.values[group_index].append(
new_value
.into_iter()
Expand Down Expand Up @@ -364,6 +348,124 @@ where
}
}

struct StringArrayAggGroupsAccumulator {
values: Vec<Vec<Option<String>>>,
null_state: NullState,
}

impl StringArrayAggGroupsAccumulator {
pub fn new() -> Self {
Self {
values: vec![],
null_state: NullState::new(),
}
}
}

impl StringArrayAggGroupsAccumulator {
fn build_list(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let array = emit_to.take_needed(&mut self.values);
let nulls = self.null_state.build(emit_to);

assert_eq!(array.len(), nulls.len());

let mut builder =
ListBuilder::with_capacity(StringBuilder::new(), nulls.len());
for (is_valid, arr) in nulls.iter().zip(array.iter()) {
if is_valid {
for value in arr.iter() {
builder.values().append_option(value.as_deref());
}
builder.append(true);
} else {
builder.append_null();
}
}

Ok(Arc::new(builder.finish()))
}
}

impl GroupsAccumulator for StringArrayAggGroupsAccumulator {
fn update_batch(
&mut self,
new_values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(new_values.len(), 1, "single argument to update_batch");
let new_values = new_values[0].as_string();

self.values.resize(total_num_groups, vec![]);

self.null_state.accumulate_string(
group_indices,
new_values,
opt_filter,
total_num_groups,
|group_index, new_value| {
self.values[group_index].push(Some(new_value.to_string()));
},
);

Ok(())
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to merge_batch");
let values = values[0].as_list();

self.values.resize(total_num_groups, Vec::<Option<String>>::new());

self.null_state.accumulate_array(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value: ArrayRef| {
let new_value = new_value.as_string::<i32>();

self.values[group_index].append(new_value
.into_iter()
.map(|s| s.map(|s| s.to_string()))
.collect::<Vec<Option<String>>>()
.as_mut());
},
);

Ok(())
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
Ok(self.build_list(emit_to)?)
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
Ok(vec![self.build_list(emit_to)?])
}

fn size(&self) -> usize {
self.values.capacity() +
self.values.iter().map(
|arr|
arr.iter().map(
|e|
e.as_ref().map(|s| s.len()).unwrap_or(0)
).sum::<usize>()
).sum::<usize>()

+ self.null_state.size()
}
}


#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -398,8 +500,6 @@ mod tests {
let expected = ScalarValue::from($EXPECTED);

assert_eq!(expected, actual);

Ok(()) as Result<(), DataFusionError>
}};
}

Expand All @@ -426,8 +526,6 @@ mod tests {
));
let actual = aggregate_new(&batch, agg)?;
assert_eq!($EXPECTED, &actual);

Ok(()) as Result<(), DataFusionError>
}};
}

Expand All @@ -453,7 +551,30 @@ mod tests {
);

let expected: ArrayRef = Arc::new(list);
test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32)
test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32);

Ok(())
}

#[test]
fn array_agg_str() -> Result<()> {
let a: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3", "4", "5"]));

let mut list_builder = ListBuilder::with_capacity(StringBuilder::new(), 5);
list_builder.values().append_value("1");
list_builder.values().append_value("2");
list_builder.values().append_value("3");
list_builder.values().append_value("4");
list_builder.values().append_value("5");
list_builder.append(true);

let list = list_builder.finish();
let expected = ScalarValue::List(Arc::new(list.clone()));

let expected: ArrayRef = Arc::new(list);
test_op_new!(a, DataType::Utf8, ArrayAgg, &expected, DataType::Utf8);

Ok(())
}

#[test]
Expand Down Expand Up @@ -519,6 +640,8 @@ mod tests {
ArrayAgg,
list,
DataType::List(Arc::new(Field::new("item", DataType::Int32, true,)))
)
);

Ok(())
}
}
Loading

0 comments on commit 3885af2

Please sign in to comment.