From 0067eb33c5842a6fe4bbb1e43d170b82c91a2ba5 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Fri, 17 May 2024 10:17:42 -0400 Subject: [PATCH] Handle (partially) dictionary values in ScalarValue serde --- datafusion/proto/proto/datafusion.proto | 6 + datafusion/proto/src/generated/pbjson.rs | 133 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 13 ++ .../proto/src/logical_plan/from_proto.rs | 49 ++++++- datafusion/proto/src/logical_plan/to_proto.rs | 9 +- .../tests/cases/roundtrip_logical_plan.rs | 49 +++++++ datafusion/substrait/src/serializer.rs | 2 +- 7 files changed, 258 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 766ca6633ee1..e9d170f30851 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -948,9 +948,15 @@ message Union{ // Used for List/FixedSizeList/LargeList/Struct message ScalarNestedValue { + message Dictionary { + bytes ipc_message = 1; + bytes arrow_data = 2; + } + bytes ipc_message = 1; bytes arrow_data = 2; Schema schema = 3; + repeated Dictionary dictionaries = 4; } message ScalarTime32Value { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f2814956ef1b..3b9bfb9750a1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -23242,6 +23242,9 @@ impl serde::Serialize for ScalarNestedValue { if self.schema.is_some() { len += 1; } + if !self.dictionaries.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarNestedValue", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] @@ -23254,6 +23257,9 @@ impl serde::Serialize for ScalarNestedValue { if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } + if !self.dictionaries.is_empty() { + struct_ser.serialize_field("dictionaries", &self.dictionaries)?; + } struct_ser.end() } } @@ -23269,6 +23275,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { "arrow_data", "arrowData", "schema", + "dictionaries", ]; #[allow(clippy::enum_variant_names)] @@ -23276,6 +23283,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { IpcMessage, ArrowData, Schema, + Dictionaries, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23300,6 +23308,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), "schema" => Ok(GeneratedField::Schema), + "dictionaries" => Ok(GeneratedField::Dictionaries), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23322,6 +23331,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { let mut ipc_message__ = None; let mut arrow_data__ = None; let mut schema__ = None; + let mut dictionaries__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::IpcMessage => { @@ -23346,18 +23356,141 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { } schema__ = map_.next_value()?; } + GeneratedField::Dictionaries => { + if dictionaries__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaries")); + } + dictionaries__ = Some(map_.next_value()?); + } } } Ok(ScalarNestedValue { ipc_message: ipc_message__.unwrap_or_default(), arrow_data: arrow_data__.unwrap_or_default(), schema: schema__, + dictionaries: dictionaries__.unwrap_or_default(), }) } } deserializer.deserialize_struct("datafusion.ScalarNestedValue", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for scalar_nested_value::Dictionary { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.ipc_message.is_empty() { + len += 1; + } + if !self.arrow_data.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarNestedValue.Dictionary", len)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; + } + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + IpcMessage, + ArrowData, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = scalar_nested_value::Dictionary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarNestedValue.Dictionary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut ipc_message__ = None; + let mut arrow_data__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); + } + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); + } + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + } + } + Ok(scalar_nested_value::Dictionary { + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarNestedValue.Dictionary", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarTime32Value { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ecc94fcdaf99..4375c3226c84 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1147,6 +1147,19 @@ pub struct ScalarNestedValue { pub arrow_data: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "3")] pub schema: ::core::option::Option, + #[prost(message, repeated, tag = "4")] + pub dictionaries: ::prost::alloc::vec::Vec, +} +/// Nested message and enum types in `ScalarNestedValue`. +pub mod scalar_nested_value { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Dictionary { + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 19edd71a3a80..8f55aa2df363 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; use crate::protobuf::{ @@ -29,6 +30,7 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; +use arrow::array::ArrayRef; use arrow::{ array::AsArray, buffer::Buffer, @@ -587,6 +589,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let protobuf::ScalarNestedValue { ipc_message, arrow_data, + dictionaries, schema, } = &v; @@ -613,11 +616,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { ) })?; + let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data); + + let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List dictionary message" + .to_string(), + ) + })?; + + let id = dict_batch.id(); + + let fields_using_this_dictionary = schema.fields_with_dict_id(id); + let first_field = fields_using_this_dictionary.first().ok_or_else(|| { + Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string()) + })?; + + let values: ArrayRef = match first_field.data_type() { + DataType::Dictionary(_, ref value_type) => { + // Make a fake schema for the dictionary batch. + let value = value_type.as_ref().clone(); + let schema = Schema::new(vec![Field::new("", value, true)]); + // Read a single column + let record_batch = read_record_batch( + &buffer, + dict_batch.data().unwrap(), + Arc::new(schema), + &Default::default(), + None, + &message.version(), + )?; + Ok(record_batch.column(0).clone()) + } + _ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())), + }?; + + Ok((id,values)) + }).collect::>>()?; + let record_batch = read_record_batch( &buffer, ipc_batch, Arc::new(schema), - &Default::default(), + &dict_by_id, None, &message.version(), ) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 11fc7362c75d..592988097f01 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1604,7 +1604,7 @@ fn encode_scalar_nested_value( let gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen + let (encoded_dictionaries, encoded_message) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) .map_err(|e| { Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) @@ -1615,6 +1615,13 @@ fn encode_scalar_nested_value( let scalar_list_value = protobuf::ScalarNestedValue { ipc_message: encoded_message.ipc_message, arrow_data: encoded_message.arrow_data, + dictionaries: encoded_dictionaries + .into_iter() + .map(|data| protobuf::scalar_nested_value::Dictionary { + ipc_message: data.ipc_message, + arrow_data: data.arrow_data, + }) + .collect(), schema: Some(schema), }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3c43f100750f..c2ea3f546b61 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1092,11 +1092,60 @@ fn round_trip_scalar_values() { ) .build() .unwrap(), + ScalarStructBuilder::new() + .with_scalar( + Field::new("a", DataType::Int32, true), + ScalarValue::from(23i32), + ) + .with_scalar( + Field::new("b", DataType::Boolean, false), + ScalarValue::from(false), + ) + .with_scalar( + Field::new( + "c", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + false, + ), + ScalarValue::Dictionary( + Box::new(DataType::UInt16), + Box::new("value".into()), + ), + ) + .build() + .unwrap(), ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Boolean, false), ]))) .unwrap(), + ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, false), + Field::new( + "c", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Binary), + ), + false, + ), + Field::new( + "d", + DataType::new_list( + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Binary), + ), + false, + ), + false, + ), + ]))) + .unwrap(), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), ScalarValue::FixedSizeBinary(0, None), ScalarValue::FixedSizeBinary(5, None), diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 6b81e33dfc37..0eb8bd6cd405 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -27,7 +27,7 @@ use substrait::proto::Plan; use std::fs::OpenOptions; use std::io::{Read, Write}; -#[allow(clippy::suspicious_open_options)] +#[allow(clippy::nonsensical_open_options)] pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let protobuf_out = serialize_bytes(sql, ctx).await; let mut file = OpenOptions::new().create(true).write(true).open(path)?;