Skip to content

Commit

Permalink
feat: support "large" arrow data types (#307)
Browse files Browse the repository at this point in the history
* feat: add large arrow type support

* remove old match entry
  • Loading branch information
Mause committed Jun 3, 2024
1 parent a1aa55a commit bbc85d7
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 36 deletions.
27 changes: 8 additions & 19 deletions src/row.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{convert, sync::Arc};

use super::{Error, Result, Statement};
use crate::types::{self, EnumType, FromSql, FromSqlError, ValueRef};
use crate::types::{self, EnumType, FromSql, FromSqlError, ListType, ValueRef};

use arrow::array::DictionaryArray;
use arrow::{
Expand Down Expand Up @@ -570,22 +570,6 @@ impl<'stmt> Row<'stmt> {
_ => unimplemented!("{:?}", unit),
},
// TODO: support more data types
// DataType::List(_) => make_string_from_list!(column, row),
// DataType::Dictionary(index_type, _value_type) => match **index_type {
// DataType::Int8 => dict_array_value_to_string::<Int8Type>(column, row),
// DataType::Int16 => dict_array_value_to_string::<Int16Type>(column, row),
// DataType::Int32 => dict_array_value_to_string::<Int32Type>(column, row),
// DataType::Int64 => dict_array_value_to_string::<Int64Type>(column, row),
// DataType::UInt8 => dict_array_value_to_string::<UInt8Type>(column, row),
// DataType::UInt16 => dict_array_value_to_string::<UInt16Type>(column, row),
// DataType::UInt32 => dict_array_value_to_string::<UInt32Type>(column, row),
// DataType::UInt64 => dict_array_value_to_string::<UInt64Type>(column, row),
// _ => Err(ArrowError::InvalidArgumentError(format!(
// "Pretty printing not supported for {:?} due to index type",
// column.data_type()
// ))),
// },

// NOTE: DataTypes not supported by duckdb
// DataType::Date64 => make_string_date!(array::Date64Array, column, row),
// DataType::Time32(unit) if *unit == TimeUnit::Second => {
Expand All @@ -597,10 +581,15 @@ impl<'stmt> Row<'stmt> {
// DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => {
// make_string_time!(array::Time64NanosecondArray, column, row)
// }
DataType::List(_data) => {
DataType::LargeList(..) => {
let arr = column.as_any().downcast_ref::<array::LargeListArray>().unwrap();

ValueRef::List(ListType::Large(arr), row)
}
DataType::List(..) => {
let arr = column.as_any().downcast_ref::<ListArray>().unwrap();

ValueRef::List(arr, row)
ValueRef::List(ListType::Regular(arr), row)
}
DataType::Dictionary(key_type, ..) => {
let column = column.as_any();
Expand Down
12 changes: 11 additions & 1 deletion src/test_all_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@ use crate::{

#[test]
fn test_all_types() -> crate::Result<()> {
let database = Connection::open_in_memory()?;
test_with_database(&Connection::open_in_memory()?)
}

#[test]
fn test_large_arrow_types() -> crate::Result<()> {
let cfg = crate::Config::default().with("arrow_large_buffer_size", "true")?;
let database = Connection::open_in_memory_with_flags(cfg)?;

test_with_database(&database)
}

fn test_with_database(database: &Connection) -> crate::Result<()> {
let excluded = vec![
// uhugeint, time_tz, and dec38_10 aren't supported in the duckdb arrow layer
"uhugeint",
Expand Down
8 changes: 3 additions & 5 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub use self::{
from_sql::{FromSql, FromSqlError, FromSqlResult},
to_sql::{ToSql, ToSqlOutput},
value::Value,
value_ref::{EnumType, TimeUnit, ValueRef},
value_ref::{EnumType, ListType, TimeUnit, ValueRef},
};

use arrow::datatypes::DataType;
Expand Down Expand Up @@ -181,14 +181,12 @@ impl From<&DataType> for Type {
DataType::Binary => Self::Blob,
// DataType::FixedSizeBinary(_) => Self::FixedSizeBinary,
// DataType::LargeBinary => Self::LargeBinary,
DataType::Utf8 => Self::Text,
// DataType::LargeUtf8 => Self::LargeUtf8,
DataType::LargeUtf8 | DataType::Utf8 => Self::Text,
DataType::List(inner) => Self::List(Box::new(Type::from(inner.data_type()))),
// DataType::FixedSizeList(field, size) => Self::Array,
// DataType::LargeList(_) => Self::LargeList,
DataType::LargeList(inner) => Self::List(Box::new(Type::from(inner.data_type()))),
// DataType::Struct(inner) => Self::Struct,
// DataType::Union(_, _) => Self::Union,
// DataType::Dictionary(_, _) => Self::Enum,
DataType::Decimal128(..) => Self::Decimal,
DataType::Decimal256(..) => Self::Decimal,
// DataType::Map(field, ..) => Self::Map,
Expand Down
54 changes: 43 additions & 11 deletions src/types/value_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::types::{FromSqlError, FromSqlResult};
use crate::Row;
use rust_decimal::prelude::*;

use arrow::array::{Array, DictionaryArray, ListArray};
use arrow::array::{Array, ArrayRef, DictionaryArray, LargeListArray, ListArray};
use arrow::datatypes::{UInt16Type, UInt32Type, UInt8Type};

/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
Expand Down Expand Up @@ -75,11 +75,20 @@ pub enum ValueRef<'a> {
nanos: i64,
},
/// The value is a list
List(&'a ListArray, usize),
List(ListType<'a>, usize),
/// The value is an enum
Enum(EnumType<'a>, usize),
}

/// Wrapper type for different list sizes
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ListType<'a> {
/// The underlying list is a `ListArray`
Regular(&'a ListArray),
/// The underlying list is a `LargeListArray`
Large(&'a LargeListArray),
}

/// Wrapper type for different enum sizes
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum EnumType<'a> {
Expand Down Expand Up @@ -116,7 +125,10 @@ impl ValueRef<'_> {
ValueRef::Date32(_) => Type::Date32,
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
ValueRef::List(arr, _) => arr.data_type().into(),
ValueRef::List(arr, _) => match arr {
ListType::Large(arr) => arr.data_type().into(),
ListType::Regular(arr) => arr.data_type().into(),
},
ValueRef::Enum(..) => Type::Enum,
}
}
Expand Down Expand Up @@ -177,14 +189,26 @@ impl From<ValueRef<'_>> for Value {
ValueRef::Date32(d) => Value::Date32(d),
ValueRef::Time64(t, d) => Value::Time64(t, d),
ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos },
ValueRef::List(items, idx) => {
let offsets = items.offsets();
let range = offsets[idx]..offsets[idx + 1];
let map: Vec<Value> = range
.map(|row| Row::value_ref_internal(row.try_into().unwrap(), idx, items.values()).to_owned())
.collect();
Value::List(map)
}
ValueRef::List(items, idx) => match items {
ListType::Regular(items) => {
let offsets = items.offsets();
from_list(
offsets[idx].try_into().unwrap(),
offsets[idx + 1].try_into().unwrap(),
idx,
items.values(),
)
}
ListType::Large(items) => {
let offsets = items.offsets();
from_list(
offsets[idx].try_into().unwrap(),
offsets[idx + 1].try_into().unwrap(),
idx,
items.values(),
)
}
},
ValueRef::Enum(items, idx) => {
let value = Row::value_ref_internal(
idx,
Expand All @@ -207,6 +231,14 @@ impl From<ValueRef<'_>> for Value {
}
}

fn from_list(start: usize, end: usize, idx: usize, values: &ArrayRef) -> Value {
Value::List(
(start..end)
.map(|row| Row::value_ref_internal(row, idx, values).to_owned())
.collect(),
)
}

impl<'a> From<&'a str> for ValueRef<'a> {
#[inline]
fn from(s: &str) -> ValueRef<'_> {
Expand Down

0 comments on commit bbc85d7

Please sign in to comment.