Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for reading lists #292

Merged
merged 15 commits into from
Apr 17, 2024
12 changes: 11 additions & 1 deletion src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{convert, sync::Arc};
use super::{Error, Result, Statement};
use crate::types::{self, FromSql, FromSqlError, ValueRef};

use arrow::array::{ArrayRef, ListArray};
use arrow::{
array::{self, Array, StructArray},
datatypes::*,
Expand Down Expand Up @@ -339,6 +340,10 @@ impl<'stmt> Row<'stmt> {

/// Fetches the value stored at (`row`, `col`) of the array currently held by this row.
///
/// The column lookup happens here; decoding of the cell itself is delegated to
/// [`Self::value_ref_internal`].
///
/// # Panics
///
/// Panics if `self.arr` holds no array (the `unwrap` below fails).
fn value_ref(&self, row: usize, col: usize) -> ValueRef<'_> {
    let batch = self.arr.as_ref().as_ref().unwrap();
    let column = batch.column(col);
    Self::value_ref_internal(row, col, column)
}

pub(crate) fn value_ref_internal(row: usize, col: usize, column: &ArrayRef) -> ValueRef {
if column.is_null(row) {
return ValueRef::Null;
}
Expand Down Expand Up @@ -578,7 +583,12 @@ impl<'stmt> Row<'stmt> {
// DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => {
// make_string_time!(array::Time64NanosecondArray, column, row)
// }
_ => unreachable!("invalid value: {}, {}", col, self.stmt.column_type(col)),
DataType::List(_data) => {
let arr = column.as_any().downcast_ref::<ListArray>().unwrap();

ValueRef::List(arr, row)
}
_ => unreachable!("invalid value: {} {}", col, column.data_type()),
}
}

Expand Down
136 changes: 128 additions & 8 deletions src/test_all_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pretty_assertions::assert_eq;
use rust_decimal::Decimal;

use crate::{
types::{TimeUnit, ValueRef},
types::{TimeUnit, Type, Value, ValueRef},
Connection,
};

Expand All @@ -22,13 +22,6 @@ fn test_all_types() -> crate::Result<()> {
"small_enum",
"medium_enum",
"large_enum",
"int_array",
"double_array",
"date_array",
"timestamp_array",
"timestamptz_array",
"varchar_array",
"nested_int_array",
"struct",
"struct_of_arrays",
"array_of_structs",
Expand Down Expand Up @@ -58,6 +51,9 @@ fn test_all_types() -> crate::Result<()> {
idx += 1;
for column in row.stmt.column_names() {
let value = row.get_ref_unwrap(row.stmt.column_index(&column)?);
if idx != 2 {
assert_ne!(value.data_type(), Type::Null);
}
test_single(&mut idx, column, value);
}
}
Expand Down Expand Up @@ -214,6 +210,122 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
1 => assert_eq!(value, ValueRef::Blob(&[0, 0, 0, 97])),
_ => assert_eq!(value, ValueRef::Null),
},
"int_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Int(42),
Value::Int(999),
Value::Null,
Value::Null,
Value::Int(-42),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"double_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => {
let value = value.to_owned();

if let Value::List(values) = value {
assert_eq!(values.len(), 6);
assert_eq!(values[0], Value::Double(42.0));
assert!(unwrap(&values[1]).is_nan());
let val = unwrap(&values[2]);
assert!(val.is_infinite() && val.is_sign_positive());
let val = unwrap(&values[3]);
assert!(val.is_infinite() && val.is_sign_negative());
assert_eq!(values[4], Value::Null);
assert_eq!(values[5], Value::Double(-42.0));
}
}
_ => assert_eq!(value, ValueRef::Null),
},
"date_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Date32(0),
Value::Date32(2147483647),
Value::Date32(-2147483647),
Value::Null,
Value::Date32(19124),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"timestamp_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Timestamp(TimeUnit::Microsecond, 0,),
Value::Timestamp(TimeUnit::Microsecond, 9223372036854775807,),
Value::Timestamp(TimeUnit::Microsecond, -9223372036854775807,),
Value::Null,
Value::Timestamp(TimeUnit::Microsecond, 1652372625000000,),
],)
),
_ => assert_eq!(value, ValueRef::Null),
},
"timestamptz_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Timestamp(TimeUnit::Microsecond, 0,),
Value::Timestamp(TimeUnit::Microsecond, 9223372036854775807,),
Value::Timestamp(TimeUnit::Microsecond, -9223372036854775807,),
Value::Null,
Value::Timestamp(TimeUnit::Microsecond, 1652397825000000,),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"varchar_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Text("🦆🦆🦆🦆🦆🦆".to_string()),
Value::Text("goose".to_string()),
Value::Null,
Value::Text("".to_string()),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"nested_int_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => {
assert_eq!(
value.to_owned(),
Value::List(vec![
Value::List(vec![],),
Value::List(vec![
Value::Int(42,),
Value::Int(999,),
Value::Null,
Value::Null,
Value::Int(-42,),
],),
Value::Null,
Value::List(vec![],),
Value::List(vec![
Value::Int(42,),
Value::Int(999,),
Value::Null,
Value::Null,
Value::Int(-42,),
],),
],)
)
}
_ => assert_eq!(value, ValueRef::Null),
},
"bit" => match idx {
0 => assert_eq!(value, ValueRef::Blob(&[1, 145, 46, 42, 215]),),
1 => assert_eq!(value, ValueRef::Blob(&[3, 245])),
Expand All @@ -222,3 +334,11 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
_ => todo!("{column:?}"),
}
}

/// Extracts the `f64` payload of a [`Value::Double`].
///
/// Test helper for assertions on floating-point list elements (NaN, infinities)
/// that cannot be compared with `assert_eq!`.
///
/// # Panics
///
/// Panics if `value` is any other variant; the message names the variant that was
/// found so a failing test points at the offending element.
fn unwrap(value: &Value) -> f64 {
    match value {
        Value::Double(val) => *val,
        other => panic!("expected Value::Double, got {other:?}"),
    }
}
46 changes: 46 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub use self::{
value_ref::{TimeUnit, ValueRef},
};

use arrow::datatypes::DataType;
use std::fmt;

#[cfg(feature = "chrono")]
Expand Down Expand Up @@ -144,10 +145,54 @@ pub enum Type {
Date32,
/// TIME64
Time64,
/// LIST
List(Box<Type>),
/// Any
Any,
}

/// Translates an Arrow [`DataType`] into the matching [`Type`].
///
/// List types are converted recursively from the data type of their element field,
/// and both decimal widths collapse into [`Type::Decimal`].
///
/// # Panics
///
/// Hits `unimplemented!` for every Arrow data type that has no arm below.
impl From<&DataType> for Type {
    fn from(value: &DataType) -> Self {
        match value {
            DataType::Null => Self::Null,
            DataType::Boolean => Self::Boolean,

            // Signed integers.
            DataType::Int8 => Self::TinyInt,
            DataType::Int16 => Self::SmallInt,
            DataType::Int32 => Self::Int,
            DataType::Int64 => Self::BigInt,

            // Unsigned integers.
            DataType::UInt8 => Self::UTinyInt,
            DataType::UInt16 => Self::USmallInt,
            DataType::UInt32 => Self::UInt,
            DataType::UInt64 => Self::UBigInt,

            // NOTE(review): Float64 is mapped to `Float` while Float16/Float32 have
            // no mapping yet -- confirm this is the intended variant.
            DataType::Float64 => Self::Float,

            // Temporal types; the time unit / timezone parameters are ignored.
            DataType::Timestamp(..) => Self::Timestamp,
            DataType::Date32 => Self::Date32,
            DataType::Time64(..) => Self::Time64,

            DataType::Binary => Self::Blob,
            DataType::Utf8 => Self::Text,

            // The element type is derived from the list's field, recursively.
            DataType::List(field) => {
                let element = Self::from(field.data_type());
                Self::List(Box::new(element))
            }

            DataType::Decimal128(..) | DataType::Decimal256(..) => Self::Decimal,

            // Not mapped yet, so these end up in the arm below: Float16, Float32,
            // Date64, Time32, Duration, Interval, FixedSizeBinary, LargeBinary,
            // LargeUtf8, FixedSizeList, LargeList, Struct, Union, Dictionary, Map.
            unsupported => unimplemented!("{}", unsupported),
        }
    }
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Expand All @@ -170,6 +215,7 @@ impl fmt::Display for Type {
Type::Blob => f.pad("Blob"),
Type::Date32 => f.pad("Date32"),
Type::Time64 => f.pad("Time64"),
Type::List(..) => f.pad("List"),
Type::Any => f.pad("Any"),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub enum Value {
Date32(i32),
/// The value is a time64
Time64(TimeUnit, i64),
/// The value is a list
List(Vec<Value>),
}

impl From<Null> for Value {
Expand Down Expand Up @@ -212,6 +214,7 @@ impl Value {
Value::Blob(_) => Type::Blob,
Value::Date32(_) => Type::Date32,
Value::Time64(..) => Type::Time64,
Value::List(_) => todo!(),
}
}
}
20 changes: 20 additions & 0 deletions src/types/value_ref.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use super::{Type, Value};
use crate::types::{FromSqlError, FromSqlResult};

use crate::Row;
use rust_decimal::prelude::*;

use arrow::array::{Array, ListArray};

/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
/// Copy from arrow::datatypes::TimeUnit
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand Down Expand Up @@ -61,6 +64,8 @@ pub enum ValueRef<'a> {
Date32(i32),
/// The value is a time64
Time64(TimeUnit, i64),
/// The value is a list
List(&'a ListArray, usize),
}

impl ValueRef<'_> {
Expand All @@ -87,8 +92,14 @@ impl ValueRef<'_> {
ValueRef::Blob(_) => Type::Blob,
ValueRef::Date32(_) => Type::Date32,
ValueRef::Time64(..) => Type::Time64,
ValueRef::List(arr, _) => arr.data_type().into(),
}
}

/// Returns an owned version of this ValueRef
///
/// The conversion goes through the `Into<Value>` implementation for `ValueRef`;
/// for a `ValueRef::List` that implementation collects the list's elements into
/// a `Vec<Value>`.
pub fn to_owned(&self) -> Value {
Mause marked this conversation as resolved.
Show resolved Hide resolved
// `*self` takes the `ValueRef` out of the borrow by value so `into()` can consume it.
(*self).into()
}
}

impl<'a> ValueRef<'a> {
Expand Down Expand Up @@ -140,6 +151,14 @@ impl From<ValueRef<'_>> for Value {
ValueRef::Blob(b) => Value::Blob(b.to_vec()),
ValueRef::Date32(d) => Value::Date32(d),
ValueRef::Time64(t, d) => Value::Time64(t, d),
ValueRef::List(items, idx) => {
let offsets = items.offsets();
let range = offsets[idx]..offsets[idx + 1];
let map: Vec<Value> = range
.map(|row| Row::value_ref_internal(row.try_into().unwrap(), idx, items.values()).to_owned())
.collect();
Value::List(map)
}
}
}
}
Expand Down Expand Up @@ -181,6 +200,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Blob(ref b) => ValueRef::Blob(b),
Value::Date32(d) => ValueRef::Date32(d),
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::List(..) => unimplemented!(),
}
}
}
Expand Down
Loading