Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add was_valid parameter to NullState callbacks #11592

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,9 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
|group_index, _, new_value| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);

self.counts[group_index] += 1;
},
);
Expand All @@ -318,7 +317,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
|group_index, _, partial_count| {
self.counts[group_index] += partial_count;
},
);
Expand All @@ -330,7 +329,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
partial_prods,
opt_filter,
total_num_groups,
|group_index, new_value: <Float64Type as ArrowPrimitiveType>::Native| {
|group_index, _, new_value| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
},
Expand Down
7 changes: 3 additions & 4 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,9 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
|group_index, _, new_value| {
let sum = &mut self.sums[group_index];
*sum = sum.add_wrapping(new_value);

self.counts[group_index] += 1;
},
);
Expand Down Expand Up @@ -533,7 +532,7 @@ where
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
|group_index, _, partial_count| {
self.counts[group_index] += partial_count;
},
);
Expand All @@ -545,7 +544,7 @@ where
partial_sums,
opt_filter,
total_num_groups,
|group_index, new_value: <T as ArrowPrimitiveType>::Native| {
|group_index, _, new_value: <T as ArrowPrimitiveType>::Native| {
let sum = &mut self.sums[group_index];
*sum = sum.add_wrapping(new_value);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,18 @@ impl NullState {
self.seen_values.capacity() / 8
}

/// Invokes `value_fn(group_index, value)` for each non null, non
/// filtered value of `value`, while tracking which groups have
/// seen null inputs and which groups have seen any inputs if necessary
//
/// Invokes `value_fn(group_index, was_valid, new_value)` for each non-null,
/// non-filtered new value, while tracking which groups have seen null inputs
/// and which groups have seen any inputs if necessary.
///
/// `was_valid` indicates whether the group has seen other values previously.
///
/// # Arguments:
///
/// * `values`: the input arguments to the accumulator
/// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index)
/// * `opt_filter`: if present, only rows for which is Some(true) are included
/// * `value_fn`: function invoked for (group_index, value) where value is non null
/// * `values`: The input arguments to the accumulator.
/// * `group_indices`: To which groups do the rows in `values` belong, (aka `group_index`).
/// * `opt_filter`: If present, only rows for which is `Some(true)` are included.
/// * `value_fn`: Function invoked for `(group_index, was_valid, new_value)` where the new value is not null.
///
/// # Example
///
Expand All @@ -111,18 +113,20 @@ impl NullState {
/// group_indices values opt_filter
/// ```
///
/// In the example above, `value_fn` is invoked for each (group_index,
/// value) pair where `opt_filter[i]` is true and values is non null
/// In the example above, `value_fn` is invoked for each `(group_index, was_valid, new_value)`
/// tuple where `opt_filter[i]` is true and the new value is not null.
/// In the first invocation of `value_fn` per `group_index`, `was_value` will be `false`,
/// and in all subsequent invocations for the same group, it will be `true`.
///
/// ```text
/// value_fn(2, 200)
/// value_fn(0, 200)
/// value_fn(0, 300)
/// value_fn(2, false, 200)
/// value_fn(0, false, 200)
/// value_fn(0, true, 300)
/// ```
///
/// It also sets
///
/// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale
/// 1. `self.seen_values[group_index]` to true for all rows that had a non-null value.
pub fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
Expand All @@ -132,7 +136,7 @@ impl NullState {
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
F: FnMut(usize, bool, T::Native) + Send,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please update the documentation to reflect this new argument and explains what it means

{
let data: &[T::Native] = values.values();
assert_eq!(data.len(), group_indices.len());
Expand All @@ -147,8 +151,9 @@ impl NullState {
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
}
// nulls, no filter
Expand All @@ -174,8 +179,9 @@ impl NullState {
// valid bit was set, real value
let is_valid = (mask & index_mask) != 0;
if is_valid {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
index_mask <<= 1;
},
Expand All @@ -191,8 +197,9 @@ impl NullState {
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
});
}
Expand All @@ -208,8 +215,9 @@ impl NullState {
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
})
}
Expand All @@ -226,25 +234,22 @@ impl NullState {
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, was_valid, new_value)
}
}
})
}
}
}

/// Invokes `value_fn(group_index, value)` for each non null, non
/// filtered value in `values`, while tracking which groups have
/// seen null inputs and which groups have seen any inputs, for
/// [`BooleanArray`]s.
///
/// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
/// handled specially.
/// Invokes `value_fn(group_index, was_valid, new_value)` for each non-null,
/// non-filtered value in `values`, while tracking which groups have seen null inputs
/// and which groups have seen any inputs, for [`BooleanArray`]s.
///
/// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
/// more details on other arguments.
/// Since `BooleanArray` is not a [`PrimitiveArray`] it must be handled specially.
/// See [`Self::accumulate`], which handles `PrimitiveArray`s, for more details.
pub fn accumulate_boolean<F>(
&mut self,
group_indices: &[usize],
Expand All @@ -253,7 +258,7 @@ impl NullState {
total_num_groups: usize,
mut value_fn: F,
) where
F: FnMut(usize, bool) + Send,
F: FnMut(usize, bool, bool) + Send,
{
let data = values.values();
assert_eq!(data.len(), group_indices.len());
Expand All @@ -271,8 +276,9 @@ impl NullState {
// buffer is big enough (start everything at valid)
group_indices.iter().zip(data.iter()).for_each(
|(&group_index, new_value)| {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, was_valid, new_value)
},
)
}
Expand All @@ -285,8 +291,9 @@ impl NullState {
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
})
}
Expand All @@ -300,8 +307,9 @@ impl NullState {
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
})
}
Expand All @@ -315,8 +323,9 @@ impl NullState {
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, was_valid, new_value)
}
}
})
Expand Down Expand Up @@ -352,14 +361,13 @@ impl NullState {
}

/// This function is called to update the accumulator state per row
/// when the value is not needed (e.g. COUNT)
/// when the value is not needed (e.g. `COUNT`).
///
/// `F`: Invoked like `value_fn(group_index) for all non null values
/// `F`: Invoked like `value_fn(group_index)` for all non-null values
/// passing the filter. Note that no tracking is done for null inputs
/// or which groups have seen any values
/// or which groups have seen any values.
///
/// See [`NullState::accumulate`], for more details on other
/// arguments.
/// See [`NullState::accumulate`], for more details.
pub fn accumulate_indices<F>(
group_indices: &[usize],
nulls: Option<&NullBuffer>,
Expand Down Expand Up @@ -669,8 +677,8 @@ mod test {
values,
opt_filter,
total_num_groups,
|group_index, value| {
accumulated_values.push((group_index, value));
|group_index, was_valid, value| {
accumulated_values.push((group_index, was_valid, value));
},
);

Expand All @@ -682,8 +690,9 @@ mod test {
None => group_indices.iter().zip(values.iter()).for_each(
|(&group_index, value)| {
if let Some(value) = value {
let was_valid = mock.expected_seen(group_index);
mock.saw_value(group_index);
expected_values.push((group_index, value));
expected_values.push((group_index, was_valid, value));
}
},
),
Expand All @@ -696,8 +705,9 @@ mod test {
// if value passed filter
if let Some(true) = is_included {
if let Some(value) = value {
let was_valid = mock.expected_seen(group_index);
mock.saw_value(group_index);
expected_values.push((group_index, value));
expected_values.push((group_index, was_valid, value));
}
}
});
Expand Down Expand Up @@ -785,8 +795,8 @@ mod test {
values,
opt_filter,
total_num_groups,
|group_index, value| {
accumulated_values.push((group_index, value));
|group_index, was_valid, value| {
accumulated_values.push((group_index, was_valid, value));
},
);

Expand All @@ -798,8 +808,9 @@ mod test {
None => group_indices.iter().zip(values.iter()).for_each(
|(&group_index, value)| {
if let Some(value) = value {
let was_valid = mock.expected_seen(group_index);
mock.saw_value(group_index);
expected_values.push((group_index, value));
expected_values.push((group_index, was_valid, value));
}
},
),
Expand All @@ -812,8 +823,9 @@ mod test {
// if value passed filter
if let Some(true) = is_included {
if let Some(value) = value {
let was_valid = mock.expected_seen(group_index);
mock.saw_value(group_index);
expected_values.push((group_index, value));
expected_values.push((group_index, was_valid, value));
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let current_value = self.values.get_bit(group_index);
let value = (self.bool_fn)(current_value, new_value);
self.values.set_bit(group_index, value);
|group_index, was_valid, new_value| {
if was_valid {
let current_value = self.values.get_bit(group_index);
let value = (self.bool_fn)(current_value, new_value);
self.values.set_bit(group_index, value)
} else {
self.values.set_bit(group_index, new_value)
}
},
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value);
|group_index, was_valid, new_value| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to add a test that covered this, if possible

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the existing implementation have to change, or do they?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps not, although conceptually an accumulator is just a map-reduce-map via some monoid or semigroup. The current implementation supports only monoids (need an empty value) but it doesn't support semigroups (no empty value). We could use this for something like FirstValue or AnyValue but I guess we could also implement it specifically for that use case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could even drop the initial value parameter since it doesn't matter in this case.

if was_valid {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value)
} else {
self.values[group_index] = new_value
}
},
);

Expand Down