Skip to content

Commit

Permalink
Fix TakeFn on sliced Bitpacked array (#775)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 committed Sep 10, 2024
1 parent bc295db commit dbaf477
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
34 changes: 24 additions & 10 deletions encodings/fastlanes/src/bitpacking/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ fn take_primitive<T: NativePType + BitPacking>(
indices
.maybe_null_slice::<$P>()
.iter()
.chunk_by(|idx| (**idx / 1024) as usize)
.map(|i| *i as usize + array.offset())
.chunk_by(|idx| idx / 1024)
.into_iter()
.map(|(k, g)| (k, g.map(|idx| (*idx % 1024) as u16).collect()))
.map(|(k, g)| (k, g.map(|idx| (idx % 1024) as u16).collect()))
.collect()
});

Expand Down Expand Up @@ -133,10 +134,10 @@ mod test {
use itertools::Itertools;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use vortex::array::{Primitive, PrimitiveArray, SparseArray};
use vortex::compute::take;
use vortex::array::{PrimitiveArray, SparseArray};
use vortex::compute::unary::scalar_at;
use vortex::{ArrayDef, IntoArray, IntoArrayVariant};
use vortex::compute::{slice, take};
use vortex::{IntoArray, IntoArrayVariant};

use crate::BitPackedArray;

Expand All @@ -146,17 +147,30 @@ mod test {

// Create a u8 array modulo 63.
let unpacked = PrimitiveArray::from((0..4096).map(|i| (i % 63) as u8).collect::<Vec<_>>());

let bitpacked = BitPackedArray::encode(unpacked.array(), 6).unwrap();

let result = take(bitpacked.array(), &indices).unwrap();
assert_eq!(result.encoding().id(), Primitive::ID);

let primitive_result = result.into_primitive().unwrap();
let primitive_result = take(bitpacked.array(), &indices)
.unwrap()
.into_primitive()
.unwrap();
let res_bytes = primitive_result.maybe_null_slice::<u8>();
assert_eq!(res_bytes, &[0, 62, 31, 33, 9, 18]);
}

/// Regression test: `take` on a sliced bit-packed array must resolve indices relative to the
/// start of the slice rather than the start of the underlying packed data.
#[test]
fn take_sliced_indices() {
    // 4096 u8 values cycling through 0..=62, i.e. the value at row `r` is `r % 63`.
    let unpacked = PrimitiveArray::from((0..4096).map(|i| (i % 63) as u8).collect::<Vec<_>>());
    let bitpacked = BitPackedArray::encode(unpacked.array(), 6).unwrap();

    // The slice begins at row 128, so positions 1919 and 1921 within it correspond to rows
    // 2047 and 2049 of the unsliced data: one on either side of a multiple of 1024.
    let window = slice(bitpacked.array(), 128, 2050).unwrap();
    let picks = PrimitiveArray::from(vec![1919, 1921]).into_array();

    let taken = take(&window, &picks).unwrap().into_primitive().unwrap();

    // 2047 % 63 == 31 and 2049 % 63 == 33.
    assert_eq!(taken.maybe_null_slice::<u8>(), &[31, 33]);
}

#[test]
#[cfg_attr(miri, ignore)] // This test is too slow on miri
fn take_random_indices() {
Expand Down
8 changes: 4 additions & 4 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,21 @@ impl TakeFn for RunEndArray {
}
Validity::Array(original_validity) => {
let dense_validity = take(&original_validity, indices)?;
let filtered_values = filter(&dense_values, &dense_validity)?;
let length = dense_validity.len();
let dense_nonnull_indices = PrimitiveArray::from(
dense_validity
.clone()
.into_bool()?
.boolean_buffer()
.set_indices()
.map(|idx| idx as u64)
.collect::<Vec<u64>>(),
.collect::<Vec<_>>(),
)
.into_array();
let length = dense_validity.len();

SparseArray::try_new(
dense_nonnull_indices,
filter(&dense_values, &dense_validity)?,
filtered_values,
length,
Scalar::null(self.dtype().clone()),
)?
Expand Down
27 changes: 17 additions & 10 deletions fuzz/fuzz_targets/fuzz_target_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@ use vortex_scalar::{PValue, Scalar, ScalarValue};
fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
let FuzzArrayAction { array, actions } = fuzz_action;
let mut current_array = array.clone();
for (action, expected) in actions {
for (i, (action, expected)) in actions.into_iter().enumerate() {
match action {
Action::Compress(c) => {
match fuzz_compress(&current_array.into_canonical().unwrap().into(), &c) {
Some(compressed_array) => {
assert_array_eq(&expected.array(), &compressed_array);
assert_array_eq(&expected.array(), &compressed_array, i);
current_array = compressed_array;
}
None => return Corpus::Reject,
}
}
Action::Slice(range) => {
current_array = slice(&current_array, range.start, range.end).unwrap();
assert_array_eq(&expected.array(), &current_array);
assert_array_eq(&expected.array(), &current_array, i);
}
Action::Take(indices) => {
if indices.is_empty() {
return Corpus::Reject;
}
current_array = take(&current_array, &indices).unwrap();
assert_array_eq(&expected.array(), &current_array);
assert_array_eq(&expected.array(), &current_array, i);
}
Action::SearchSorted(s, side) => {
// TODO(robert): Ideally we'd preserve the encoding perfectly but this is close enough
Expand All @@ -51,7 +51,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
sorted =
fuzz_compress(&sorted, &SamplingCompressor::default()).unwrap_or(sorted);
}
assert_search_sorted(sorted, s, side, expected.search())
assert_search_sorted(sorted, s, side, expected.search(), i)
}
}
}
Expand All @@ -67,23 +67,29 @@ fn fuzz_compress(array: &Array, compressor: &SamplingCompressor) -> Option<Array
.then(|| compressed_array.into_array())
}

fn assert_search_sorted(array: Array, s: Scalar, side: SearchSortedSide, expected: SearchResult) {
fn assert_search_sorted(
array: Array,
s: Scalar,
side: SearchSortedSide,
expected: SearchResult,
step: usize,
) {
let search_result = search_sorted(&array, s.clone(), side).unwrap();
assert_eq!(
search_result,
expected,
"Expected to find {s} at {expected} in ({}) but instead found it at {search_result}",
"Expected to find {s} at {expected} in ({}) but instead found it at {search_result} in step {step}",
array.encoding().id()
);
}

fn assert_array_eq(lhs: &Array, rhs: &Array) {
fn assert_array_eq(lhs: &Array, rhs: &Array, step: usize) {
assert_eq!(lhs.len(), rhs.len());
for idx in 0..lhs.len() {
let l = scalar_at(lhs, idx).unwrap();
let r = scalar_at(rhs, idx).unwrap();

fuzzing_scalar_cmp(l, r, lhs.encoding().id(), rhs.encoding().id(), idx);
fuzzing_scalar_cmp(l, r, lhs.encoding().id(), rhs.encoding().id(), idx, step);
}
}

Expand All @@ -93,6 +99,7 @@ fn fuzzing_scalar_cmp(
lhs_encoding: EncodingId,
rhs_encoding: EncodingId,
idx: usize,
step: usize,
) {
let equal_values = match (l.value(), r.value()) {
(ScalarValue::Primitive(l), ScalarValue::Primitive(r))
Expand All @@ -110,7 +117,7 @@ fn fuzzing_scalar_cmp(

assert!(
equal_values,
"{l} != {r} at index {idx}, lhs is {lhs_encoding} rhs is {rhs_encoding}",
"{l} != {r} at index {idx}, lhs is {lhs_encoding} rhs is {rhs_encoding} in step {step}",
);
assert_eq!(l.is_valid(), r.is_valid());
}

0 comments on commit dbaf477

Please sign in to comment.