diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index c53ef18ba58..f60536fc2a9 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -99,6 +99,12 @@ impl Buffer { buffer.into() } + /// Creates an empty [`Buffer`], built from a zero-length byte slice. + pub fn new_empty() -> Self { + let empty_slice: &[u8] = &[]; + Self::from_slice_ref(empty_slice) + } + /// Creates a buffer from an existing aligned memory region (must already be byte-aligned), this /// `Buffer` will free this piece of memory when dropped. /// diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 65ccbe1e01a..e5c6360d3fd 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -30,7 +30,7 @@ use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndB use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; use arrow_data::transform::MutableArrayData; -use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_data::{ArrayData, ArrayDataBuilder, ByteView}; use arrow_schema::*; /// If the filter selects more than this fraction of rows, use @@ -551,6 +551,58 @@ fn filter_native(values: &[T], predicate: &FilterPredicate) buffer.into() } +#[inline(never)] +fn filter_native_for_byte_view( + values: &[u128], + predicate: &FilterPredicate, + mut for_left_view: F, +) -> Buffer { + assert!(values.len() >= predicate.filter.len()); + + let buffer = match &predicate.strategy { + IterationStrategy::SlicesIterator => { + let mut buffer = MutableBuffer::with_capacity(predicate.count * u128::get_byte_width()); + for (start, end) in SlicesIterator::new(&predicate.filter) { + let left_views = &values[start..end]; + left_views.iter().for_each(&mut for_left_view); + buffer.extend_from_slice(left_views); + } + buffer + } + IterationStrategy::Slices(slices) => { + let mut buffer = MutableBuffer::with_capacity(predicate.count * u128::get_byte_width()); + for 
(start, end) in slices { + let left_views = &values[*start..*end]; + left_views.iter().for_each(&mut for_left_view); + buffer.extend_from_slice(left_views); + } + buffer + } + IterationStrategy::IndexIterator => { + let wrap_map_lambda = |x: usize| { + for_left_view(&values[x]); + values[x] + }; + let iter = IndexIterator::new(&predicate.filter, predicate.count).map(wrap_map_lambda); + // SAFETY: `IndexIterator` is trusted length, and `map` preserves the length + unsafe { MutableBuffer::from_trusted_len_iter(iter) } + } + IterationStrategy::Indices(indices) => { + let wrap_map_lambda = |x: &usize| { + for_left_view(&values[*x]); + values[*x] + }; + let iter = indices.iter().map(wrap_map_lambda); + + // SAFETY: `Vec::iter` is trusted length, and `map` preserves the length + unsafe { MutableBuffer::from_trusted_len_iter(iter) } + } + IterationStrategy::All | IterationStrategy::None => unreachable!(), + }; + + buffer.into() +} + /// `filter` implementation for primitive arrays pub(crate) fn filter_primitive( array: &PrimitiveArray, @@ -693,12 +745,41 @@ fn filter_byte_view( array: &GenericByteViewArray, predicate: &FilterPredicate, ) -> GenericByteViewArray { - let new_view_buffer = filter_native(array.views(), predicate); + let mut left_buffers = vec![false; array.data_buffers().len()]; + let get_left_buffer = |view: &u128| { + let byte_view = ByteView::from(*view); + if byte_view.length <= 12 { + return; + } + let buffer_index = byte_view.buffer_index as usize; + // SAFETY: `left_buffers` has one entry per data buffer, so this is in bounds whenever the view's `buffer_index` names an existing data buffer. NOTE(review): that relies on the array's views being valid; confirm for arrays built without validation. + unsafe { + *left_buffers.get_unchecked_mut(buffer_index) = true; + }; + }; + let new_view_buffer = filter_native_for_byte_view(array.views(), predicate, get_left_buffer); + + let new_buffers = array.data_buffers().to_vec(); + assert!(left_buffers.len() == new_buffers.len()); + // Buffers not marked in `left_buffers` are swapped for empty ones; every position is kept so each retained view's `buffer_index` still refers to the same buffer. 
+ let new_buffers = new_buffers + .into_iter() + .zip(left_buffers) + .map( + |(buffer, left)| { + if left { + buffer + } else { + Buffer::new_empty() + } + }, + ) + .collect(); let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) .len(predicate.count) .add_buffer(new_view_buffer) - .add_buffers(array.data_buffers().to_vec()); + .add_buffers(new_buffers); if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));