diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index aaca4dc48236..f2c5b4b080b2 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -76,9 +76,10 @@ impl From<&protobuf::PhysicalColumn> for Column { /// # Arguments /// /// * `proto` - Input proto with physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, registry: &dyn FunctionRegistry, @@ -102,9 +103,10 @@ pub fn parse_physical_sort_expr( /// # Arguments /// /// * `proto` - Input proto with vector of physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], registry: &dyn FunctionRegistry, @@ -123,9 +125,9 @@ pub fn parse_physical_sort_exprs( /// /// # Arguments /// -/// * `proto` - Input proto with physical window exprression node. +/// * `proto` - Input proto with physical window expression node. /// * `name` - Name of the window expression. -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. pub fn parse_physical_window_expr( @@ -133,15 +135,29 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { - let codec = DefaultPhysicalExtensionCodec {}; + parse_physical_window_expr_ext( + proto, + registry, + input_schema, + &DefaultPhysicalExtensionCodec {}, + ) +} + +// TODO: Make this the public function on next major release. +pub(crate) fn parse_physical_window_expr_ext( + proto: &protobuf::PhysicalWindowExprNode, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, +) -> Result> { let window_node_expr = - parse_physical_exprs(&proto.args, registry, input_schema, &codec)?; + parse_physical_exprs(&proto.args, registry, input_schema, codec)?; let partition_by = - parse_physical_exprs(&proto.partition_by, registry, input_schema, &codec)?; + parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; let order_by = - parse_physical_sort_exprs(&proto.order_by, registry, input_schema, &codec)?; + parse_physical_sort_exprs(&proto.order_by, registry, input_schema, codec)?; let window_frame = proto .window_frame @@ -187,9 +203,10 @@ where /// # Arguments /// /// * `proto` - Input proto with physical expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, @@ -213,6 +230,7 @@ pub fn parse_physical_expr( registry, "left", input_schema, + codec, )?, logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( @@ -220,6 +238,7 @@ pub fn parse_physical_expr( registry, "right", input_schema, + codec, )?, )), ExprType::AggregateExpr(_) => { @@ -241,6 +260,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)) } ExprType::IsNotNullExpr(e) => { @@ -249,6 +269,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)) } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( @@ -256,6 +277,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)), ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( @@ -263,6 +285,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)) } ExprType::InList(e) => in_list( @@ -271,6 +294,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, parse_physical_exprs(&e.list, registry, input_schema, codec)?, &e.negated, @@ -290,12 +314,14 @@ pub fn parse_physical_expr( registry, "when_expr", input_schema, + codec, )?, parse_required_physical_expr( e.then_expr.as_ref(), registry, "then_expr", input_schema, + codec, )?, )) }) @@ -311,6 +337,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, convert_required!(e.arrow_type)?, None, @@ -321,6 +348,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, convert_required!(e.arrow_type)?, )), @@ -371,12 +399,14 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), registry, "pattern", input_schema, + codec, )?, )), }; @@ -389,9 +419,9 @@ fn parse_required_physical_expr( registry: &dyn FunctionRegistry, field: &str, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let codec = DefaultPhysicalExtensionCodec {}; - expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) + expr.map(|e| parse_physical_expr(e, registry, input_schema, codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -433,15 +463,29 @@ pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, registry: &dyn FunctionRegistry, input_schema: &Schema, +) -> Result> { + parse_protobuf_hash_partitioning_ext( + partitioning, + registry, + input_schema, + &DefaultPhysicalExtensionCodec {}, + ) +} + +// TODO: Make this the public function on next major release. +fn parse_protobuf_hash_partitioning_ext( + partitioning: Option<&protobuf::PhysicalHashRepartition>, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { match partitioning { Some(hash_part) => { - let codec = DefaultPhysicalExtensionCodec {}; let expr = parse_physical_exprs( &hash_part.hash_expr, registry, input_schema, - &codec, + codec, )?; Ok(Some(Partitioning::Hash( @@ -456,6 +500,19 @@ pub fn parse_protobuf_hash_partitioning( pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, registry: &dyn FunctionRegistry, +) -> Result { + parse_protobuf_file_scan_config_ext( + proto, + registry, + &DefaultPhysicalExtensionCodec {}, + ) +} + +// TODO: Make this the public function on next major release. +pub(crate) fn parse_protobuf_file_scan_config_ext( + proto: &protobuf::FileScanExecConf, + registry: &dyn FunctionRegistry, + codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema: Arc = Arc::new(convert_required!(proto.schema)?); let projection = proto @@ -489,7 +546,7 @@ pub fn parse_protobuf_file_scan_config( .collect::>>()?; // Remove partition columns from the schema after recreating table_partition_cols - // because the partition columns are not in the file. They are present to allow the + // because the partition columns are not in the file. They are present to allow // the partition column types to be reconstructed after serde. let file_schema = Arc::new(Schema::new( schema @@ -502,12 +559,11 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { - let codec = DefaultPhysicalExtensionCodec {}; let sort_expr = parse_physical_sort_exprs( &node_collection.physical_sort_expr_nodes, registry, &schema, - &codec, + codec, )?; output_ordering.push(sort_expr); } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 00dacffe06c2..b841d412a405 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -19,22 +19,8 @@ use std::convert::TryInto; use std::fmt::Debug; use std::sync::Arc; -use self::from_proto::parse_physical_window_expr; -use self::to_proto::serialize_physical_expr; - -use crate::common::{byte_to_string, proto_error, str_to_byte}; -use crate::convert_required; -use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_protobuf_file_scan_config, -}; -use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{ - self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, -}; +use prost::bytes::BufMut; +use prost::Message; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; @@ -79,8 +65,22 @@ use datafusion::physical_plan::{ use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::ScalarUDF; -use prost::bytes::BufMut; -use prost::Message; +use crate::common::{byte_to_string, proto_error, str_to_byte}; +use crate::convert_required; +use crate::physical_plan::from_proto::{ + parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, + parse_physical_window_expr_ext, parse_protobuf_file_scan_config, + parse_protobuf_file_scan_config_ext, +}; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::physical_plan_node::PhysicalPlanType; +use crate::protobuf::repartition_exec_node::PartitionMethod; +use crate::protobuf::{ + self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, +}; + +use self::to_proto::serialize_physical_expr; pub mod from_proto; pub mod to_proto; @@ -188,9 +188,10 @@ impl AsExecutionPlan for PhysicalPlanNode { } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - parse_protobuf_file_scan_config( + parse_protobuf_file_scan_config_ext( scan.base_conf.as_ref().unwrap(), registry, + extension_codec, )?, scan.has_header, str_to_byte(&scan.delimiter, "delimiter")?, @@ -230,12 +231,13 @@ impl AsExecutionPlan for PhysicalPlanNode { Default::default(), ))) } - PhysicalPlanType::AvroScan(scan) => { - Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( + PhysicalPlanType::AvroScan(scan) => Ok(Arc::new(AvroExec::new( + parse_protobuf_file_scan_config_ext( scan.base_conf.as_ref().unwrap(), registry, - )?))) - } + extension_codec, + )?, + ))), PhysicalPlanType::CoalesceBatches(coalesce_batches) => { let input: Arc = into_physical_plan( &coalesce_batches.input, @@ -334,10 +336,11 @@ impl AsExecutionPlan for PhysicalPlanNode { .window_expr .iter() .map(|window_expr| { - parse_physical_window_expr( + parse_physical_window_expr_ext( window_expr, registry, input_schema.as_ref(), + extension_codec, ) }) .collect::, _>>()?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 4924128ae190..06c0453f0e09 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -21,6 +21,8 @@ use std::sync::Arc; use std::vec; use arrow::csv::WriterBuilder; +use prost::Message; + use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -79,19 +81,18 @@ use datafusion_expr::{ ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; -use datafusion_proto::physical_plan::from_proto::parse_physical_expr; -use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use datafusion_proto::protobuf; -use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is /// lost during serde because the string representation of a plan often only shows a subset of state. fn roundtrip_test(exec_plan: Arc) -> Result<()> { - let _ = roundtrip_test_and_return(exec_plan); + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + roundtrip_test_and_return(exec_plan, &ctx, &codec)?; Ok(()) } @@ -103,15 +104,15 @@ fn roundtrip_test(exec_plan: Arc) -> Result<()> { /// farther in tests. fn roundtrip_test_and_return( exec_plan: Arc, + ctx: &SessionContext, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let ctx = SessionContext::new(); - let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) + protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) .expect("to proto"); let runtime = ctx.runtime_env(); let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &codec) + .try_into_physical_plan(ctx, runtime.deref(), codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); Ok(result_exec_plan) @@ -125,17 +126,10 @@ fn roundtrip_test_and_return( /// performing serde on some plans. fn roundtrip_test_with_context( exec_plan: Arc, - ctx: SessionContext, + ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) - .expect("to proto"); - let runtime = ctx.runtime_env(); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &codec) - .expect("from proto"); - assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + roundtrip_test_and_return(exec_plan, ctx, &codec)?; Ok(()) } @@ -444,7 +438,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Arc::new(EmptyExec::new(schema.clone())), schema, )?), - ctx, + &ctx, ) } @@ -667,7 +661,7 @@ fn roundtrip_scalar_udf() -> Result<()> { ctx.register_udf(udf); - roundtrip_test_with_context(Arc::new(project), ctx) + roundtrip_test_with_context(Arc::new(project), &ctx) } #[test] @@ -682,11 +676,7 @@ fn roundtrip_scalar_udf_extension_codec() { impl MyRegexUdf { fn new(pattern: String) -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Int32], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), pattern, } } @@ -707,7 +697,7 @@ fn roundtrip_scalar_udf_extension_codec() { if !matches!(args.first(), Some(&DataType::Utf8)) { return plan_err!("regex_udf only accepts Utf8 arguments"); } - Ok(DataType::Int32) + Ok(DataType::Boolean) } fn invoke(&self, _args: &[ColumnarValue]) -> Result { unimplemented!() @@ -772,32 +762,40 @@ fn roundtrip_scalar_udf_extension_codec() { } } + let field_text = Field::new("text", DataType::Utf8, true); + let field_published = Field::new("published", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_text, field_published])); + let input = Arc::new(EmptyExec::new(schema.clone())); + let pattern = ".*"; let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); - let test_expr = ScalarFunctionExpr::new( + let udf_expr = Arc::new(ScalarFunctionExpr::new( udf.name(), ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), - vec![], - DataType::Int32, + vec![col("text", &schema).expect("text")], + DataType::Boolean, None, false, + )); + + let filter = Arc::new( + FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("published", &schema).expect("published"), + Operator::And, + udf_expr, + )), + input, + ) + .expect("filter"), ); - let fmt_expr = format!("{test_expr:?}"); - let ctx = SessionContext::new(); - ctx.register_udf(udf.clone()); - let extension_codec = ScalarUDFExtensionCodec {}; - let proto: protobuf::PhysicalExprNode = - match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { - Ok(proto) => proto, - Err(e) => panic!("failed to serialize expr: {e:?}"), - }; - let field_a = Field::new("a", DataType::Int32, false); - let schema = Arc::new(Schema::new(vec![field_a])); - let round_trip = - parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); - assert_eq!(fmt_expr, format!("{round_trip:?}")); + let ctx = SessionContext::new(); + let codec = ScalarUDFExtensionCodec {}; + ctx.register_udf(udf); + roundtrip_test_and_return(filter, &ctx, &codec).unwrap(); } + #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); @@ -921,12 +919,18 @@ fn roundtrip_csv_sink() -> Result<()> { }), )]; - let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new( - input, - data_sink, - schema.clone(), - Some(sort_order), - ))) + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let roundtrip_plan = roundtrip_test_and_return( + Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + )), + &ctx, + &codec, + ) .unwrap(); let roundtrip_plan = roundtrip_plan