Skip to content

Commit

Permalink
Groups accumulator for array_agg (#233)
Browse files Browse the repository at this point in the history
* Groups accumulator for array_agg

* small fix

* fmt

* clippy

* clippy
  • Loading branch information
lkt authored and joroKr21 committed Jun 13, 2024
1 parent 02867fd commit 59edbdf
Show file tree
Hide file tree
Showing 2 changed files with 704 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
//!
//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator

use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::array::BooleanBufferBuilder;
use arrow::array::{
Array, ArrayRef, BooleanArray, ListArray, PrimitiveArray, StringArray,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::ArrowPrimitiveType;

Expand Down Expand Up @@ -324,6 +327,166 @@ impl NullState {
}
}

/// Invokes `value_fn(group_index, value)` for each non null, non
/// filtered value in `values`, while tracking which groups have
/// seen null inputs and which groups have seen any inputs, for
/// [`ListArray`]s.
///
/// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
/// more details on other arguments.
pub fn accumulate_array<F>(
&mut self,
group_indices: &[usize],
values: &ListArray,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut value_fn: F,
) where
F: FnMut(usize, ArrayRef) + Send,
{
assert_eq!(values.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(values.iter());
for (&group_index, new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
}
// nulls, no filter
(true, None) => {
let nulls = values.nulls().unwrap();
group_indices
.iter()
.zip(values.iter())
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
})
}
// no nulls, but a filter
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
group_indices
.iter()
.zip(values.iter())
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
});
}
// both null values and filters
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
}
});
}
}
}

/// Invokes `value_fn(group_index, value)` for each non-null,
/// non-filtered value in `values`, while tracking which groups have
/// seen null inputs and which groups have seen any inputs, for
/// [`ListArray`]s.
///
/// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
/// more details on other arguments.
pub fn accumulate_string<F>(
&mut self,
group_indices: &[usize],
values: &StringArray,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut value_fn: F,
) where
F: FnMut(usize, &str) + Send,
{
assert_eq!(values.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(values.iter());
for (&group_index, new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
}
// nulls, no filter
(true, None) => {
let nulls = values.nulls().unwrap();
group_indices
.iter()
.zip(values.iter())
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
})
}
// no nulls, but a filter
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
group_indices
.iter()
.zip(values.iter())
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
});
}
// both null values and filters
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
}
});
}
}
}

/// Creates the a [`NullBuffer`] representing which group_indices
/// should have null values (because they never saw any values)
/// for the `emit_to` rows.
Expand Down
Loading

0 comments on commit 59edbdf

Please sign in to comment.