diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs
index 941c6ea9..3bb1693e 100644
--- a/crates/duckdb/src/vtab/arrow.rs
+++ b/crates/duckdb/src/vtab/arrow.rs
@@ -8,7 +8,7 @@ use crate::vtab::vector::Inserter;
 use arrow::array::{
     as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array,
     as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray,
-    GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
+    GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray,
 };
 
 use arrow::{
@@ -229,6 +229,16 @@ pub fn record_batch_to_duckdb_data_chunk(
             DataType::Utf8 => {
                 string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i));
             }
+            DataType::LargeUtf8 => {
+                // `as_string_array` yields an i32-offset `StringArray`; `LargeUtf8` needs an explicit downcast.
+                string_array_to_vector(
+                    col.as_ref()
+                        .as_any()
+                        .downcast_ref::<LargeStringArray>()
+                        .ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to LargeStringArray"))?,
+                    &mut chunk.flat_vector(i),
+                );
+            }
             DataType::Binary => {
                 binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
             }
@@ -453,7 +463,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
     }
 }
 
-fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) {
+fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
     assert!(array.len() <= out.capacity());
 
     // TODO: zero copy assignment
@@ -612,12 +622,12 @@ mod test {
     use arrow::{
         array::{
             Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
-            FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
-            StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
-            TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
+            FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, LargeStringArray, ListArray,
+            OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
+            TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
         },
         buffer::{OffsetBuffer, ScalarBuffer},
-        datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
+        datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Fields, Schema},
         record_batch::RecordBatch,
     };
     use std::{error::Error, sync::Arc};
@@ -784,6 +794,53 @@ mod test {
         Ok(())
     }
 
+    /// Sends `arry_in` through the `arrow` table function and asserts that the result matches `arry_out`.
+    ///
+    /// `T1` and `T2` are independent so a caller can expect a different byte-array type back than
+    /// the one it sent (e.g. `LargeUtf8` in, `Utf8` out). Panics on any mismatch.
+    fn check_generic_byte_roundtrip<T1, T2>(
+        arry_in: GenericByteArray<T1>,
+        arry_out: GenericByteArray<T2>,
+    ) -> Result<(), Box<dyn Error>>
+    where
+        T1: ByteArrayType,
+        T2: ByteArrayType,
+    {
+        let db = Connection::open_in_memory()?;
+        db.register_table_function::<ArrowVTab>("arrow")?;
+
+        // Roundtrip a record batch from Rust to DuckDB and back to Rust
+        let schema = Schema::new(vec![Field::new("a", arry_in.data_type().clone(), false)]);
+
+        let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry_in.clone())])?;
+        let param = arrow_recordbatch_to_query_params(rb);
+        let mut stmt = db.prepare("select a from arrow(?, ?)")?;
+        let rb = stmt.query_arrow(param)?.next().expect("no record batch");
+
+        let output_any_array = rb.column(0);
+
+        assert!(
+            output_any_array.data_type().equals_datatype(arry_out.data_type()),
+            "{} != {}",
+            output_any_array.data_type(),
+            arry_out.data_type()
+        );
+
+        match output_any_array.as_bytes_opt::<T2>() {
+            Some(output_array) => {
+                assert_eq!(output_array.len(), arry_out.len());
+                for i in 0..output_array.len() {
+                    assert_eq!(output_array.is_valid(i), arry_out.is_valid(i));
+                    // Compare row by row so offsets are checked too, not only the concatenated value buffer.
+                    assert_eq!(AsRef::<[u8]>::as_ref(output_array.value(i)), AsRef::<[u8]>::as_ref(arry_out.value(i)));
+                }
+            }
+            None => panic!("Expected GenericByteArray"),
+        }
+
+        Ok(())
+    }
+
     #[test]
     fn test_array_roundtrip() -> Result<(), Box<dyn Error>> {
         check_generic_array_roundtrip(ListArray::new(
@@ -862,6 +919,21 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
+        check_generic_byte_roundtrip(
+            StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
+            StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
+        )?;
+
+        // A [`LargeStringArray`] input is expected to come back as a [`StringArray`].
+        check_generic_byte_roundtrip(
+            LargeStringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
+            StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
+        )?;
+        Ok(())
+    }
+
     #[test]
     fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
         check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;