Skip to content

Commit

Permalink
Handle dictionary values in ScalarValue serde (#10563)
Browse files Browse the repository at this point in the history
* Handle dictionary values in ScalarValue serde

* Do not panic on failed physical expr decoding (#241)

* revert clippy change
  • Loading branch information
thinkharderdev committed May 17, 2024
1 parent 4e55768 commit e7858ff
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 4 deletions.
6 changes: 6 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,15 @@ message Union{

// Used for List/FixedSizeList/LargeList/Struct
message ScalarNestedValue {
// One Arrow IPC-encoded dictionary belonging to the nested value.
// The decoder requires `ipc_message` to carry a dictionary-batch header.
message Dictionary {
// Serialized Arrow IPC message for this dictionary.
bytes ipc_message = 1;
// Raw Arrow buffer data that accompanies `ipc_message`.
bytes arrow_data = 2;
}

// Serialized Arrow IPC message for the record batch holding the value.
bytes ipc_message = 1;
// Raw Arrow buffer data that accompanies `ipc_message`.
bytes arrow_data = 2;
// Arrow schema used to decode the record batch.
Schema schema = 3;
// Dictionaries emitted while IPC-encoding the batch; the decoder indexes
// them by dictionary id before reading the record batch.
repeated Dictionary dictionaries = 4;
}

message ScalarTime32Value {
Expand Down
133 changes: 133 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 48 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::{
array::AsArray,
buffer::Buffer,
Expand Down Expand Up @@ -522,6 +524,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
let protobuf::ScalarNestedValue {
ipc_message,
arrow_data,
dictionaries,
schema,
} = &v;

Expand All @@ -548,11 +551,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
)
})?;

// Decode every dictionary shipped with the nested value and index the decoded
// values by dictionary id. Any malformed dictionary message is reported as an
// error rather than a panic, since these bytes come from a serialized plan.
let dict_by_id: HashMap<i64, ArrayRef> = dictionaries
    .iter()
    .map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
        let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
            Error::General(format!(
                "Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
            ))
        })?;
        let buffer = Buffer::from(arrow_data);

        let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
            Error::General(
                "Unexpected message type deserializing ScalarValue::List dictionary message"
                    .to_string(),
            )
        })?;

        let id = dict_batch.id();

        // The dictionary's value type is taken from the first schema field
        // that refers to this dictionary id.
        let fields_using_this_dictionary = schema.fields_with_dict_id(id);
        let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
            Error::General(
                "dictionary id not found in schema while deserializing ScalarValue::List"
                    .to_string(),
            )
        })?;

        let values: ArrayRef = match first_field.data_type() {
            DataType::Dictionary(_, value_type) => {
                // The payload is optional in the IPC format; a message without
                // it must yield an error instead of an `unwrap` panic.
                let dict_data = dict_batch.data().ok_or_else(|| {
                    Error::General(
                        "Missing dictionary data while deserializing ScalarValue::List dictionary message"
                            .to_string(),
                    )
                })?;

                // Make a fake single-column schema for the dictionary batch.
                let value = value_type.as_ref().clone();
                let dict_schema = Schema::new(vec![Field::new("", value, true)]);

                // Read a single column: the dictionary values.
                let record_batch = read_record_batch(
                    &buffer,
                    dict_data,
                    Arc::new(dict_schema),
                    &Default::default(),
                    None,
                    &message.version(),
                )?;
                Ok(record_batch.column(0).clone())
            }
            _ => Err(Error::General(
                "dictionary id not found in schema while deserializing ScalarValue::List"
                    .to_string(),
            )),
        }?;

        Ok((id, values))
    })
    .collect::<Result<HashMap<_, _>>>()?;

let record_batch = read_record_batch(
&buffer,
ipc_batch,
Arc::new(schema),
&Default::default(),
&dict_by_id,
None,
&message.version(),
)
Expand Down
9 changes: 8 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ fn encode_scalar_nested_value(

let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded_message) = gen
let (encoded_dictionaries, encoded_message) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
Expand All @@ -1508,6 +1508,13 @@ fn encode_scalar_nested_value(
let scalar_list_value = protobuf::ScalarNestedValue {
ipc_message: encoded_message.ipc_message,
arrow_data: encoded_message.arrow_data,
dictionaries: encoded_dictionaries
.into_iter()
.map(|data| protobuf::scalar_nested_value::Dictionary {
ipc_message: data.ipc_message,
arrow_data: data.arrow_data,
})
.collect(),
schema: Some(schema),
};

Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
match expr_type {
ExprType::AggregateExpr(agg_node) => {
let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter()
.map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect();
.map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
let ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
.map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect();
.map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
agg_node.aggregate_function.as_ref().map(|func| {
match func {
AggregateFunction::AggrFunction(i) => {
Expand Down
49 changes: 49 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,9 +1104,58 @@ fn round_trip_scalar_values() {
)
.build()
.unwrap(),
ScalarStructBuilder::new()
.with_scalar(
Field::new("a", DataType::Int32, true),
ScalarValue::from(23i32),
)
.with_scalar(
Field::new("b", DataType::Boolean, false),
ScalarValue::from(false),
)
.with_scalar(
Field::new(
"c",
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Utf8),
),
false,
),
ScalarValue::Dictionary(
Box::new(DataType::UInt16),
Box::new("value".into()),
),
)
.build()
.unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
])))
.unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
Field::new(
"c",
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Binary),
),
false,
),
Field::new(
"d",
DataType::new_list(
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Binary),
),
false,
),
false,
),
])))
.unwrap(),
ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())),
Expand Down

0 comments on commit e7858ff

Please sign in to comment.