diff --git a/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs b/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs index 772ea0692..9a33fe158 100644 --- a/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs +++ b/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs @@ -2,13 +2,15 @@ use std::cmp::Ordering; use std::cmp::Ordering::Greater; use fastlanes::BitPacking; +use num_traits::AsPrimitive; use vortex::array::{PrimitiveArray, SparseArray}; use vortex::compute::{ search_sorted, IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide, }; +use vortex::validity::Validity; use vortex::{ArrayDType, IntoArrayVariant}; use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; -use vortex_error::VortexResult; +use vortex_error::{VortexError, VortexResult}; use vortex_scalar::Scalar; use crate::{unpack_single_primitive, BitPackedArray}; @@ -16,20 +18,41 @@ use crate::{unpack_single_primitive, BitPackedArray}; impl SearchSortedFn for BitPackedArray { fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { match_each_unsigned_integer_ptype!(self.ptype(), |$P| { - let unwrapped_value: $P = value.cast(self.dtype())?.try_into().unwrap(); - if let Some(patches_array) = self.patches() { - if unwrapped_value as usize >= self.max_packed_value() { - search_sorted(&patches_array, value.clone(), side) - } else { - Ok(SearchSorted::search_sorted(&BitPackedSearch::new(self), &unwrapped_value, side)) - } - } else { - Ok(SearchSorted::search_sorted(&BitPackedSearch::new(self), &unwrapped_value, side)) - } + search_sorted_typed::<$P>(self, value, side) }) } } +fn search_sorted_typed( + array: &BitPackedArray, + value: &Scalar, + side: SearchSortedSide, +) -> VortexResult +where + T: NativePType + TryFrom + BitPacking + AsPrimitive, +{ + let unwrapped_value: T = value.cast(array.dtype())?.try_into()?; + if let Some(patches_array) = array.patches() { + // If patches exist they must be the last elements in the array, if the value we're looking for is greater than + // max packed value just search the patches + if unwrapped_value.as_() > array.max_packed_value() { + search_sorted(&patches_array, value.clone(), side) + } else { + Ok(SearchSorted::search_sorted( + &BitPackedSearch::new(array), + &unwrapped_value, + side, + )) + } + } else { + Ok(SearchSorted::search_sorted( + &BitPackedSearch::new(array), + &unwrapped_value, + side, + )) + } +} + /// This wrapper exists, so that you can't invoke SearchSorted::search_sorted directly on BitPackedArray as it omits searching patches #[derive(Debug)] struct BitPackedSearch { @@ -38,6 +61,7 @@ struct BitPackedSearch { length: usize, bit_width: usize, min_patch_offset: Option, + validity: Validity, } impl BitPackedSearch { @@ -52,6 +76,7 @@ impl BitPackedSearch { .expect("Only Sparse patches are supported") .min_index() }), + validity: array.validity(), } } } @@ -63,6 +88,11 @@ impl IndexOrd for BitPackedSearch { return Some(Greater); } } + + if self.validity.is_null(idx) { + return Some(Greater); + } + // SAFETY: Used in search_sorted_by which ensures that idx is within bounds let val: T = unsafe { unpack_single_primitive( diff --git a/encodings/fastlanes/src/bitpacking/mod.rs b/encodings/fastlanes/src/bitpacking/mod.rs index faddbb314..8773fcf34 100644 --- a/encodings/fastlanes/src/bitpacking/mod.rs +++ b/encodings/fastlanes/src/bitpacking/mod.rs @@ -164,7 +164,7 @@ impl BitPackedArray { #[inline] pub fn max_packed_value(&self) -> usize { - 1 << self.bit_width() + (1 << self.bit_width()) - 1 } } diff --git a/encodings/fastlanes/src/for/compress.rs b/encodings/fastlanes/src/for/compress.rs index d8aca00ad..9ff6372e9 100644 --- a/encodings/fastlanes/src/for/compress.rs +++ b/encodings/fastlanes/src/for/compress.rs @@ -21,7 +21,7 @@ pub fn for_compress(array: &PrimitiveArray) -> VortexResult { if shift == <$T>::PTYPE.bit_width() as u8 { match array.validity().to_logical(array.len()) { LogicalValidity::AllValid(l) => { - ConstantArray::new(Scalar::zero::(array.dtype().nullability()), l).into_array() + ConstantArray::new(Scalar::zero::<$T>(array.dtype().nullability()), l).into_array() }, LogicalValidity::AllInvalid(l) => { ConstantArray::new(Scalar::null(array.dtype().clone()), l).into_array() @@ -66,12 +66,11 @@ fn compress_primitive( ) -> PrimitiveArray { assert!(shift < T::PTYPE.bit_width() as u8); let values = if shift > 0 { - let shifted_min = min >> shift as usize; parray .maybe_null_slice::() .iter() - .map(|&v| v >> shift as usize) - .map(|v| v.wrapping_sub(&shifted_min)) + .map(|&v| v.wrapping_sub(&min)) + .map(|v| v >> shift as usize) .collect_vec() } else { parray @@ -90,9 +89,9 @@ pub fn decompress(array: FoRArray) -> VortexResult { let encoded = array.encoded().into_primitive()?.reinterpret_cast(ptype); let validity = encoded.validity(); Ok(match_each_integer_ptype!(ptype, |$T| { - let reference: $T = array.reference().try_into()?; + let min: $T = array.reference().try_into()?; PrimitiveArray::from_vec( - decompress_primitive(encoded.into_maybe_null_slice::<$T>(), reference, shift), + decompress_primitive(encoded.into_maybe_null_slice::<$T>(), min, shift), validity, ) })) @@ -100,19 +99,19 @@ pub fn decompress(array: FoRArray) -> VortexResult { fn decompress_primitive( values: Vec, - reference: T, + min: T, shift: usize, ) -> Vec { if shift > 0 { values .into_iter() .map(|v| v << shift) - .map(|v| v.wrapping_add(&reference)) + .map(|v| v.wrapping_add(&min)) .collect_vec() } else { values .into_iter() - .map(|v| v.wrapping_add(&reference)) + .map(|v| v.wrapping_add(&min)) .collect_vec() } } diff --git a/encodings/fastlanes/src/for/compute.rs b/encodings/fastlanes/src/for/compute.rs index 5b1566b68..4172f9dcf 100644 --- a/encodings/fastlanes/src/for/compute.rs +++ b/encodings/fastlanes/src/for/compute.rs @@ -1,12 +1,15 @@ +use std::ops::{AddAssign, Shl, Shr}; + +use num_traits::{WrappingAdd, WrappingSub}; use vortex::compute::unary::{scalar_at_unchecked, ScalarAtFn}; use vortex::compute::{ search_sorted, slice, take, ArrayCompute, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, }; use vortex::{Array, ArrayDType, IntoArray}; -use vortex_dtype::match_each_integer_ptype; -use vortex_error::VortexResult; -use vortex_scalar::{PrimitiveScalar, Scalar}; +use vortex_dtype::{match_each_integer_ptype, NativePType}; +use vortex_error::{VortexError, VortexResult}; +use vortex_scalar::{PValue, PrimitiveScalar, Scalar}; use crate::FoRArray; @@ -51,7 +54,6 @@ impl ScalarAtFn for FoRArray { let reference = PrimitiveScalar::try_from(self.reference()).unwrap(); match_each_integer_ptype!(encoded.ptype(), |$P| { - use num_traits::WrappingAdd; encoded.typed_value::<$P>().map(|v| (v << self.shift()).wrapping_add(reference.typed_value::<$P>().unwrap())) .map(|v| Scalar::primitive::<$P>(v, encoded.dtype().nullability())) .unwrap_or_else(|| Scalar::null(encoded.dtype().clone())) @@ -73,32 +75,72 @@ impl SliceFn for FoRArray { impl SearchSortedFn for FoRArray { fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { match_each_integer_ptype!(self.ptype(), |$P| { - let min: $P = self.reference().try_into().unwrap(); - let shifted_min = min >> self.shift(); - let unwrapped_value: $P = value.cast(self.dtype())?.try_into().unwrap(); - let shifted_value: $P = unwrapped_value >> self.shift(); - // Make sure that smaller values are still smaller and not larger than (which they would be after wrapping_sub) - if shifted_value < shifted_min { - return Ok(SearchResult::NotFound(0)); - } - - let translated_scalar = Scalar::primitive( - shifted_value.wrapping_sub(shifted_min), - value.dtype().nullability(), - ) - .reinterpret_cast(self.ptype().to_unsigned()); - search_sorted(&self.encoded(), translated_scalar, side) + search_sorted_typed::<$P>(self, value, side) }) } } +fn search_sorted_typed( + array: &FoRArray, + value: &Scalar, + side: SearchSortedSide, +) -> VortexResult +where + T: NativePType + + for<'a> TryFrom<&'a Scalar, Error = VortexError> + + Shr + + Shl + + WrappingSub + + WrappingAdd + + AddAssign + + Into, +{ + let min: T = array.reference().try_into()?; + let primitive_value: T = value.cast(array.dtype())?.as_ref().try_into()?; + // Make sure that smaller values are still smaller and not larger than (which they would be after wrapping_sub) + if primitive_value < min { + return Ok(SearchResult::NotFound(0)); + } + + // When the values in the array are shifted, not all values in the domain are representable in the compressed + // space. Multiple different search values can translate to same value in the compressed space. + let encoded_value = primitive_value.wrapping_sub(&min) >> array.shift(); + let decoded_value = (encoded_value << array.shift()).wrapping_add(&min); + + // We first determine whether the value can be represented in the compressed array. For any value that is not + // representable, it is by definition NotFound. For NotFound values, the correct insertion index is by definition + // the same regardless of which side we search on. + // However, to correctly handle repeated values in the array, we need to search left on the next *representable* + // value (i.e., increment the translated value by 1). + let representable = decoded_value == primitive_value; + let (side, target) = if representable { + (side, encoded_value) + } else { + ( + SearchSortedSide::Left, + encoded_value.wrapping_add(&T::one()), + ) + }; + + let target_scalar = Scalar::primitive(target, value.dtype().nullability()) + .reinterpret_cast(array.ptype().to_unsigned()); + let search_result = search_sorted(&array.encoded(), target_scalar, side)?; + Ok( + if representable && matches!(search_result, SearchResult::Found(_)) { + search_result + } else { + SearchResult::NotFound(search_result.to_index()) + }, + ) +} + #[cfg(test)] mod test { use vortex::array::PrimitiveArray; use vortex::compute::unary::scalar_at; use vortex::compute::{search_sorted, SearchResult, SearchSortedSide}; - use crate::for_compress; + use crate::{for_compress, FoRArray}; #[test] fn for_scalar_at() { @@ -124,4 +166,86 @@ mod test { SearchResult::NotFound(0) ); } + + #[test] + fn search_with_shift_notfound() { + let for_arr = for_compress(&PrimitiveArray::from(vec![62, 114])).unwrap(); + assert_eq!( + search_sorted(&for_arr, 63, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(1) + ); + let for_arr = for_compress(&PrimitiveArray::from(vec![62, 114])).unwrap(); + assert_eq!( + search_sorted(&for_arr, 61, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(0) + ); + let for_arr = for_compress(&PrimitiveArray::from(vec![62, 114])).unwrap(); + assert_eq!( + search_sorted(&for_arr, 113, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(1) + ); + assert_eq!( + search_sorted(&for_arr, 115, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(2) + ); + } + + #[test] + fn search_with_shift_repeated() { + let arr = for_compress(&PrimitiveArray::from(vec![62, 62, 114, 114])).unwrap(); + let for_array = FoRArray::try_from(arr.clone()).unwrap(); + + let min: i32 = for_array.reference().try_into().unwrap(); + assert_eq!(min, 62); + assert_eq!(for_array.shift(), 1); + + assert_eq!( + search_sorted(&arr, 61, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(0) + ); + assert_eq!( + search_sorted(&arr, 61, SearchSortedSide::Right).unwrap(), + SearchResult::NotFound(0) + ); + assert_eq!( + search_sorted(&arr, 62, SearchSortedSide::Left).unwrap(), + SearchResult::Found(0) + ); + assert_eq!( + search_sorted(&arr, 62, SearchSortedSide::Right).unwrap(), + SearchResult::Found(2) + ); + assert_eq!( + search_sorted(&arr, 63, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(2) + ); + assert_eq!( + search_sorted(&arr, 63, SearchSortedSide::Right).unwrap(), + SearchResult::NotFound(2) + ); + assert_eq!( + search_sorted(&arr, 113, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(2) + ); + assert_eq!( + search_sorted(&arr, 113, SearchSortedSide::Right).unwrap(), + SearchResult::NotFound(2) + ); + assert_eq!( + search_sorted(&arr, 114, SearchSortedSide::Left).unwrap(), + SearchResult::Found(2) + ); + assert_eq!( + search_sorted(&arr, 114, SearchSortedSide::Right).unwrap(), + SearchResult::Found(4) + ); + assert_eq!( + search_sorted(&arr, 115, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(4) + ); + assert_eq!( + search_sorted(&arr, 115, SearchSortedSide::Right).unwrap(), + SearchResult::NotFound(4) + ); + } } diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs index 53d086425..ddd723cfc 100644 --- a/fuzz/fuzz_targets/fuzz_target_1.rs +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -26,7 +26,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { Corpus::Keep } Action::SearchSorted(s, side) => { - if !array_is_sorted(&array).unwrap() { + if !array_is_sorted(&array).unwrap() || s.is_null() { return Corpus::Reject; } @@ -76,7 +76,7 @@ fn assert_take(original: &Array, taken: &Array, indices: &Array) { let o = scalar_at(original, to_take).unwrap(); let s = scalar_at(taken, idx).unwrap(); - fuzzing_scalar_cmp(o, s, original.encoding().id(), taken.encoding().id(), idx); + fuzzing_scalar_cmp(o, s, original.encoding().id(), indices.encoding().id(), idx); } } @@ -122,8 +122,7 @@ fn fuzzing_scalar_cmp( assert!( equal_values, - "{l} != {r} at index {idx}, lhs is {} rhs is {}", - lhs_encoding, rhs_encoding + "{l} != {r} at index {idx}, lhs is {lhs_encoding} rhs is {rhs_encoding}", ); assert_eq!(l.is_valid(), r.is_valid()); } diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index fe2473398..596153fc0 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -2,6 +2,7 @@ use std::fmt::Debug; use std::iter; use std::ops::Range; +use libfuzzer_sys::arbitrary::Error::EmptyChoose; use libfuzzer_sys::arbitrary::{Arbitrary, Result, Unstructured}; use vortex::array::PrimitiveArray; use vortex::compute::unary::scalar_at; @@ -37,6 +38,10 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { Action::Slice(start..stop) } 2 => { + if len == 0 { + return Err(EmptyChoose); + } + let indices = PrimitiveArray::from(random_vec_in_range(u, 0, len - 1)?).into(); let compressed = SamplingCompressor::default() .compress(&indices, None) diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 2b0f89257..f63592427 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -1,19 +1,62 @@ -use vortex_dtype::match_each_native_ptype; +use std::cmp::Ordering; +use std::cmp::Ordering::Greater; + +use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::VortexResult; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; -use crate::compute::{SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide}; +use crate::compute::{IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide}; +use crate::validity::Validity; impl SearchSortedFn for PrimitiveArray { fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { match_each_native_ptype!(self.ptype(), |$T| { - let pvalue: $T = value.try_into()?; - Ok(self.maybe_null_slice::<$T>().search_sorted(&pvalue, side)) + match self.validity() { + Validity::NonNullable | Validity::AllValid => { + let pvalue: $T = value.try_into()?; + Ok(self.maybe_null_slice::<$T>().search_sorted(&pvalue, side)) + } + Validity::AllInvalid => Ok(SearchResult::NotFound(0)), + Validity::Array(_) => { + let pvalue: $T = value.try_into()?; + Ok(SearchSortedNullsLast::new(self).search_sorted(&pvalue, side)) + } + } }) } } +struct SearchSortedNullsLast<'a, T> { + values: &'a [T], + validity: Validity, +} + +impl<'a, T: NativePType> SearchSortedNullsLast<'a, T> { + pub fn new(array: &'a PrimitiveArray) -> Self { + Self { + values: array.maybe_null_slice(), + validity: array.validity(), + } + } +} + +impl<'a, T: NativePType> IndexOrd for SearchSortedNullsLast<'a, T> { + fn index_cmp(&self, idx: usize, elem: &T) -> Option { + if self.validity.is_null(idx) { + return Some(Greater); + } + + self.values.index_cmp(idx, elem) + } +} + +impl<'a, T> Len for SearchSortedNullsLast<'a, T> { + fn len(&self) -> usize { + self.values.len() + } +} + #[cfg(test)] mod test { use super::*; diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 05e526941..5e33eda47 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -43,15 +43,11 @@ impl SearchResult { Self::NotFound(i) => i, } } - - pub fn map usize>(self, f: F) -> Self { - match self { - Self::Found(i) => Self::Found(f(i)), - Self::NotFound(i) => Self::NotFound(f(i)), - } - } } +/// Searches for value assuming the array is sorted. +/// +/// For nullable arrays we assume that the nulls are sorted last, i.e. they're the greatest value pub trait SearchSortedFn { fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; } @@ -62,6 +58,10 @@ pub fn search_sorted>( side: SearchSortedSide, ) -> VortexResult { let scalar = target.into().cast(array.dtype())?; + if scalar.is_null() { + vortex_bail!("Search sorted with null value is not supported"); + } + array.with_dyn(|a| { if let Some(search_sorted) = a.search_sorted() { return search_sorted.search_sorted(&scalar, side); diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index 78a4613c2..86d7b0c10 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -13,12 +13,14 @@ use crate::pvalue::PValue; /// cast on-read. #[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum ScalarValue { - Null, Bool(bool), Primitive(PValue), Buffer(Buffer), BufferString(BufferString), List(Arc<[ScalarValue]>), + // It's significant that Null is last in this list. As a result generated PartialOrd sorts Scalar + // values such that Nulls are last (greatest) + Null, } impl ScalarValue {