From c320a17da7c3c0b942eb4b5ae79417d976ad569a Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 19 Jun 2024 10:27:37 +1000 Subject: [PATCH 1/5] support LargeUtf8 --- crates/duckdb/src/vtab/arrow.rs | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index 941c6ea9..b123979b 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -6,9 +6,7 @@ use std::ptr::null_mut; use crate::vtab::vector::Inserter; use arrow::array::{ - as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, - as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray, - GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, + as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray }; use arrow::{ @@ -228,6 +226,12 @@ pub fn record_batch_to_duckdb_data_chunk( } DataType::Utf8 => { string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i)); + }, + DataType::LargeUtf8 => { + string_array_to_vector( + col.as_ref().as_any().downcast_ref::() .ok_or_else(|| Box::::from("Unable to downcast to LargeStringArray"))?, + &mut chunk.flat_vector(i) + ); } DataType::Binary => { binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i)); @@ -453,7 +457,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) { } } -fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) { +fn string_array_to_vector(array: &GenericStringArray, out: &mut FlatVector) { assert!(array.len() <= out.capacity()); // TODO: zero copy assignment @@ -611,10 +615,7 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, - FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, - StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, FixedSizeListArray, GenericListArray, Int32Array, LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray }, buffer::{OffsetBuffer, ScalarBuffer}, datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema}, @@ -862,6 +863,15 @@ mod test { Ok(()) } + + + #[test] + fn test_utf8_roundtrip() -> Result<(), Box> { + check_rust_primitive_array_roundtrip(StringArray::from(vec![Some("foo"), None, Some("bar")]), StringArray::from(vec![Some("foo"), None, Some("bar")]))?; + check_rust_primitive_array_roundtrip(LargeStringArray::from(vec![Some("foo"), None, Some("bar")]), LargeStringArray::from(vec![Some("foo"), None, Some("bar")]))?; + Ok(()) + } + #[test] fn test_timestamp_roundtrip() -> Result<(), Box> { check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?; From 76cc14244512abc5fcba5c8c060fea2c2226cf9b Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 19 Jun 2024 12:41:21 +1000 Subject: [PATCH 2/5] lint --- crates/duckdb/src/vtab/arrow.rs | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index b123979b..78437d1e 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -6,7 +6,9 @@ use std::ptr::null_mut; use crate::vtab::vector::Inserter; use arrow::array::{ - as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray + as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, + as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray, + GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray, }; use arrow::{ @@ -226,11 +228,14 @@ pub fn record_batch_to_duckdb_data_chunk( } DataType::Utf8 => { string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i)); - }, + } DataType::LargeUtf8 => { string_array_to_vector( - col.as_ref().as_any().downcast_ref::() .ok_or_else(|| Box::::from("Unable to downcast to LargeStringArray"))?, - &mut chunk.flat_vector(i) + col.as_ref() + .as_any() + .downcast_ref::() + .ok_or_else(|| Box::::from("Unable to downcast to LargeStringArray"))?, + &mut chunk.flat_vector(i), ); } DataType::Binary => { @@ -615,7 +620,10 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, FixedSizeListArray, GenericListArray, Int32Array, LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray + Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, + FixedSizeListArray, GenericListArray, Int32Array, LargeStringArray, ListArray, OffsetSizeTrait, + PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, buffer::{OffsetBuffer, ScalarBuffer}, datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema}, @@ -863,12 +871,16 @@ mod test { Ok(()) } - - #[test] fn test_utf8_roundtrip() -> Result<(), Box> { - check_rust_primitive_array_roundtrip(StringArray::from(vec![Some("foo"), None, Some("bar")]), StringArray::from(vec![Some("foo"), None, Some("bar")]))?; - check_rust_primitive_array_roundtrip(LargeStringArray::from(vec![Some("foo"), None, Some("bar")]), LargeStringArray::from(vec![Some("foo"), None, Some("bar")]))?; + check_rust_primitive_array_roundtrip( + StringArray::from(vec![Some("foo"), None, Some("bar")]), + StringArray::from(vec![Some("foo"), None, Some("bar")]), + )?; + check_rust_primitive_array_roundtrip( + LargeStringArray::from(vec![Some("foo"), None, Some("bar")]), + LargeStringArray::from(vec![Some("foo"), None, Some("bar")]), + )?; Ok(()) } From afbe620cd0d31eda81b7af1dc1b12f109262e40e Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 19 Jun 2024 12:54:56 +1000 Subject: [PATCH 3/5] fix tests --- crates/duckdb/src/vtab/arrow.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index 78437d1e..e0f83c11 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -873,14 +873,8 @@ mod test { #[test] fn test_utf8_roundtrip() -> Result<(), Box> { - check_rust_primitive_array_roundtrip( - StringArray::from(vec![Some("foo"), None, Some("bar")]), - StringArray::from(vec![Some("foo"), None, Some("bar")]), - )?; - check_rust_primitive_array_roundtrip( - LargeStringArray::from(vec![Some("foo"), None, Some("bar")]), - LargeStringArray::from(vec![Some("foo"), None, Some("bar")]), - )?; + check_generic_array_roundtrip(StringArray::from(vec![Some("foo"), None, Some("bar")]))?; + check_generic_array_roundtrip(LargeStringArray::from(vec![Some("foo"), None, Some("bar")]))?; Ok(()) } From 06b79f4f6217a03332df043c554d003f089367f1 Mon Sep 17 00:00:00 2001 From: jeadie Date: Fri, 21 Jun 2024 13:21:03 +1000 Subject: [PATCH 4/5] Fix tests check_generic_byte_roundtrip --- crates/duckdb/src/vtab/arrow.rs | 60 ++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index e0f83c11..a7b64347 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -621,12 +621,12 @@ mod test { use arrow::{ array::{ Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, - FixedSizeListArray, GenericListArray, Int32Array, LargeStringArray, ListArray, OffsetSizeTrait, - PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, + FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, LargeStringArray, ListArray, + OffsetSizeTrait, PrimitiveArray, StringArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, buffer::{OffsetBuffer, ScalarBuffer}, - datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema}, + datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Schema}, record_batch::RecordBatch, }; use std::{error::Error, sync::Arc}; @@ -793,6 +793,48 @@ mod test { Ok(()) } + fn check_generic_byte_roundtrip( + arry_in: GenericByteArray, + arry_out: GenericByteArray, + ) -> Result<(), Box> + where + T1: ByteArrayType, + T2: ByteArrayType, + { + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + // Roundtrip a record batch from Rust to DuckDB and back to Rust + let schema = Schema::new(vec![Field::new("a", arry_in.data_type().clone(), false)]); + + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry_in.clone())])?; + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select a from arrow(?, ?)")?; + let rb = stmt.query_arrow(param)?.next().expect("no record batch"); + + let output_any_array = rb.column(0); + + assert!( + output_any_array.data_type().equals_datatype(arry_out.data_type()), + "{} != {}", + output_any_array.data_type(), + arry_out.data_type() + ); + + match output_any_array.as_bytes_opt::() { + Some(output_array) => { + assert_eq!(output_array.len(), arry_out.len()); + for i in 0..output_array.len() { + assert_eq!(output_array.is_valid(i), arry_out.is_valid(i)); + assert_eq!(output_array.value_data(), arry_out.value_data()) + } + } + None => panic!("Expected GenericByteArray"), + } + + Ok(()) + } + #[test] fn test_array_roundtrip() -> Result<(), Box> { check_generic_array_roundtrip(ListArray::new( @@ -873,8 +915,16 @@ mod test { #[test] fn test_utf8_roundtrip() -> Result<(), Box> { - check_generic_array_roundtrip(StringArray::from(vec![Some("foo"), None, Some("bar")]))?; - check_generic_array_roundtrip(LargeStringArray::from(vec![Some("foo"), None, Some("bar")]))?; + check_generic_byte_roundtrip( + StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]), + StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]), + )?; + + // [`LargeStringArray`] will be downcasted to [`StringArray`]. + check_generic_byte_roundtrip( + LargeStringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]), + StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]), + )?; Ok(()) } From 98e70b739c10ec8780e5893499e3cfae7a0c3190 Mon Sep 17 00:00:00 2001 From: jeadie Date: Fri, 21 Jun 2024 13:28:18 +1000 Subject: [PATCH 5/5] fix test --- crates/duckdb/src/vtab/arrow.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index a7b64347..3bb1693e 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -622,11 +622,11 @@ mod test { array::{ Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, LargeStringArray, ListArray, - OffsetSizeTrait, PrimitiveArray, StringArray, Time32SecondArray, Time64MicrosecondArray, + OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, buffer::{OffsetBuffer, ScalarBuffer}, - datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Schema}, + datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Fields, Schema}, record_batch::RecordBatch, }; use std::{error::Error, sync::Arc};