From 1ebde3c2392257101ead1fbe04dbcec51fb4fc44 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Fri, 2 Jun 2023 08:45:08 +0100 Subject: [PATCH] Fix MutableArrayData::extend_nulls (#1230) --- arrow-data/src/transform/mod.rs | 56 ++++++++++++++++++++------------- arrow/tests/array_transform.rs | 23 ++++++++++++++ 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index c7487507223..f4b2b46d172 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -53,7 +53,7 @@ struct _MutableArrayData<'a> { pub null_count: usize, pub len: usize, - pub null_buffer: MutableBuffer, + pub null_buffer: Option, // arrow specification only allows up to 3 buffers (2 ignoring the nulls above). // Thus, we place them in the stack to avoid bound checks and greater data locality. @@ -63,6 +63,12 @@ struct _MutableArrayData<'a> { } impl<'a> _MutableArrayData<'a> { + fn null_buffer(&mut self) -> &mut MutableBuffer { + self.null_buffer + .as_mut() + .expect("MutableArrayData not nullable") + } + fn freeze(self, dictionary: Option) -> ArrayDataBuilder { let buffers = into_buffers(&self.data_type, self.buffer1, self.buffer2); @@ -77,10 +83,13 @@ impl<'a> _MutableArrayData<'a> { } }; - let nulls = (self.null_count > 0).then(|| { - let bools = BooleanBuffer::new(self.null_buffer.into(), 0, self.len); - unsafe { NullBuffer::new_unchecked(bools, self.null_count) } - }); + let nulls = self + .null_buffer + .map(|nulls| { + let bools = BooleanBuffer::new(nulls.into(), 0, self.len); + unsafe { NullBuffer::new_unchecked(bools, self.null_count) } + }) + .filter(|n| n.null_count() > 0); ArrayDataBuilder::new(self.data_type) .offset(0) @@ -95,22 +104,25 @@ fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits if let Some(nulls) = array.nulls() { let bytes = nulls.validity(); Box::new(move |mutable, start, len| { - utils::resize_for_bits(&mut mutable.null_buffer, mutable.len + len); + let mutable_len = mutable.len; + let out = mutable.null_buffer(); + utils::resize_for_bits(out, mutable_len + len); mutable.null_count += set_bits( - mutable.null_buffer.as_slice_mut(), + out.as_slice_mut(), bytes, - mutable.len, + mutable_len, nulls.offset() + start, len, ); }) } else if use_nulls { Box::new(|mutable, _, len| { - utils::resize_for_bits(&mut mutable.null_buffer, mutable.len + len); - let write_data = mutable.null_buffer.as_slice_mut(); - let offset = mutable.len; + let mutable_len = mutable.len; + let out = mutable.null_buffer(); + utils::resize_for_bits(out, mutable_len + len); + let write_data = out.as_slice_mut(); (0..len).for_each(|i| { - bit_util::set_bit(write_data, offset + i); + bit_util::set_bit(write_data, mutable_len + i); }); }) } else { @@ -555,13 +567,10 @@ impl<'a> MutableArrayData<'a> { .map(|array| build_extend_null_bits(array, use_nulls)) .collect(); - let null_buffer = if use_nulls { + let null_buffer = use_nulls.then(|| { let null_bytes = bit_util::ceil(array_capacity, 8); MutableBuffer::from_len_zeroed(null_bytes) - } else { - // create 0 capacity mutable buffer with the intention that it won't be used - MutableBuffer::with_capacity(0) - }; + }); let extend_values = match &data_type { DataType::Dictionary(_, _) => { @@ -624,13 +633,18 @@ impl<'a> MutableArrayData<'a> { } /// Extends this [MutableArrayData] with null elements, disregarding the bound arrays + /// + /// # Panics + /// + /// Panics if [`MutableArrayData`] not created with `use_nulls` or nullable source arrays + /// pub fn extend_nulls(&mut self, len: usize) { - // TODO: null_buffer should probably be extended here as well - // otherwise is_valid() could later panic - // add test to confirm + self.data.len += len; + let bit_len = bit_util::ceil(self.data.len, 8); + let nulls = self.data.null_buffer(); + nulls.resize(bit_len, 0); self.data.null_count += len; (self.extend_nulls)(&mut self.data, len); - self.data.len += len; } /// Returns the current length diff --git a/arrow/tests/array_transform.rs b/arrow/tests/array_transform.rs index 40938c80f4c..ebbadc00aec 100644 --- a/arrow/tests/array_transform.rs +++ b/arrow/tests/array_transform.rs @@ -922,6 +922,29 @@ fn test_fixed_size_binary_append() { assert_eq!(result, expected); } +#[test] +fn test_extend_nulls() { + let int = Int32Array::from(vec![1, 2, 3, 4]).into_data(); + let mut mutable = MutableArrayData::new(vec![&int], true, 4); + mutable.extend(0, 2, 3); + mutable.extend_nulls(2); + + let data = mutable.freeze(); + data.validate_full().unwrap(); + let out = Int32Array::from(data); + + assert_eq!(out.null_count(), 2); + assert_eq!(out.iter().collect::>(), vec![Some(3), None, None]); +} + +#[test] +#[should_panic(expected = "MutableArrayData not nullable")] +fn test_extend_nulls_panic() { + let int = Int32Array::from(vec![1, 2, 3, 4]).into_data(); + let mut mutable = MutableArrayData::new(vec![&int], false, 4); + mutable.extend_nulls(2); +} + /* // this is an old test used on a meanwhile removed dead code // that is still useful when `MutableArrayData` supports fixed-size lists.