Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle (partially) dictionary values in ScalarValue serde #243

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -948,9 +948,15 @@ message Union{

// Used for List/FixedSizeList/LargeList/Struct
message ScalarNestedValue {
message Dictionary {
bytes ipc_message = 1;
bytes arrow_data = 2;
}

bytes ipc_message = 1;
bytes arrow_data = 2;
Schema schema = 3;
repeated Dictionary dictionaries = 4;
}

message ScalarTime32Value {
Expand Down
133 changes: 133 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 48 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::Arc;

use crate::protobuf::{
Expand All @@ -29,6 +30,7 @@ use crate::protobuf::{
OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
};

use arrow::array::ArrayRef;
use arrow::{
array::AsArray,
buffer::Buffer,
Expand Down Expand Up @@ -587,6 +589,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
let protobuf::ScalarNestedValue {
ipc_message,
arrow_data,
dictionaries,
schema,
} = &v;

Expand All @@ -613,11 +616,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
)
})?;

let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
))
})?;
let buffer = Buffer::from(arrow_data);

let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List dictionary message"
.to_string(),
)
})?;

let id = dict_batch.id();

let fields_using_this_dictionary = schema.fields_with_dict_id(id);
let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())
})?;

let values: ArrayRef = match first_field.data_type() {
DataType::Dictionary(_, ref value_type) => {
// Make a fake schema for the dictionary batch.
let value = value_type.as_ref().clone();
let schema = Schema::new(vec![Field::new("", value, true)]);
// Read a single column
let record_batch = read_record_batch(
&buffer,
dict_batch.data().unwrap(),
Arc::new(schema),
&Default::default(),
None,
&message.version(),
)?;
Ok(record_batch.column(0).clone())
}
_ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())),
}?;

Ok((id,values))
}).collect::<Result<HashMap<_,_>>>()?;

let record_batch = read_record_batch(
&buffer,
ipc_batch,
Arc::new(schema),
&Default::default(),
&dict_by_id,
None,
&message.version(),
)
Expand Down
9 changes: 8 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ fn encode_scalar_nested_value(

let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded_message) = gen
let (encoded_dictionaries, encoded_message) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
Expand All @@ -1615,6 +1615,13 @@ fn encode_scalar_nested_value(
let scalar_list_value = protobuf::ScalarNestedValue {
ipc_message: encoded_message.ipc_message,
arrow_data: encoded_message.arrow_data,
dictionaries: encoded_dictionaries
.into_iter()
.map(|data| protobuf::scalar_nested_value::Dictionary {
ipc_message: data.ipc_message,
arrow_data: data.arrow_data,
})
.collect(),
schema: Some(schema),
};

Expand Down
49 changes: 49 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,11 +1092,60 @@ fn round_trip_scalar_values() {
)
.build()
.unwrap(),
ScalarStructBuilder::new()
.with_scalar(
Field::new("a", DataType::Int32, true),
ScalarValue::from(23i32),
)
.with_scalar(
Field::new("b", DataType::Boolean, false),
ScalarValue::from(false),
)
.with_scalar(
Field::new(
"c",
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Utf8),
),
false,
),
ScalarValue::Dictionary(
Box::new(DataType::UInt16),
Box::new("value".into()),
),
)
.build()
.unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
])))
.unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
Field::new(
"c",
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Binary),
),
false,
),
Field::new(
"d",
DataType::new_list(
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Binary),
),
false,
),
false,
),
])))
.unwrap(),
ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())),
ScalarValue::FixedSizeBinary(0, None),
ScalarValue::FixedSizeBinary(5, None),
Expand Down
Loading
Loading