From ed3c4861e1b4c7c7c2e0e09bba24a7a119576806 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 10 Sep 2024 13:35:46 +0200 Subject: [PATCH] improve --- crates/polars-expr/src/expressions/apply.rs | 19 +++++++++++-------- crates/polars-expr/src/expressions/binary.rs | 6 +++--- crates/polars-expr/src/expressions/ternary.rs | 6 +++--- crates/polars-expr/src/planner.rs | 8 ++++---- .../src/executors/projection_utils.rs | 2 +- crates/polars-plan/src/plans/aexpr/mod.rs | 8 ++++---- crates/polars-plan/src/plans/aexpr/scalar.rs | 13 +++++++------ .../polars-plan/src/plans/aexpr/traverse.rs | 13 ++++++------- crates/polars-plan/src/plans/aexpr/utils.rs | 2 +- crates/polars-plan/src/plans/lit.rs | 2 +- 10 files changed, 41 insertions(+), 38 deletions(-) diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 75086a85d5c30..a5ea16d0f22fb 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -18,7 +18,8 @@ pub struct ApplyExpr { function: SpecialEq>, expr: Expr, collect_groups: ApplyOptions, - returns_scalar: bool, + function_returns_scalar: bool, + function_operates_on_scalar: bool, allow_rename: bool, pass_name_to_apply: bool, input_schema: Option, @@ -29,6 +30,7 @@ pub struct ApplyExpr { } impl ApplyExpr { + #[allow(clippy::too_many_arguments)] pub(crate) fn new( inputs: Vec>, function: SpecialEq>, @@ -37,7 +39,7 @@ impl ApplyExpr { allow_threading: bool, input_schema: Option, output_dtype: Option, - returns_scalar: bool + returns_scalar: bool, ) -> Self { #[cfg(debug_assertions)] if matches!(options.collect_groups, ApplyOptions::ElementWise) @@ -45,14 +47,14 @@ impl ApplyExpr { { panic!("expr {:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr) } - dbg!(returns_scalar, options.flags.contains(FunctionFlags::RETURNS_SCALAR), &expr); Self { inputs, function, expr, collect_groups: options.collect_groups, - returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR), + function_returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR), + function_operates_on_scalar: returns_scalar, allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME), pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY), input_schema, @@ -74,7 +76,8 @@ impl ApplyExpr { function, expr, collect_groups, - returns_scalar: false, + function_returns_scalar: false, + function_operates_on_scalar: false, allow_rename: false, pass_name_to_apply: false, input_schema: None, @@ -106,7 +109,7 @@ impl ApplyExpr { ca: ListChunked, ) -> PolarsResult> { let all_unit_len = all_unit_length(&ca); - if all_unit_len && self.returns_scalar { + if all_unit_len && self.function_returns_scalar { ac.with_agg_state(AggState::AggregatedScalar( ca.explode().unwrap().into_series(), )); @@ -255,7 +258,7 @@ impl ApplyExpr { let mut ac = acs.swap_remove(0); ac.with_update_groups(UpdateGroups::No); - let agg_state = if self.returns_scalar { + let agg_state = if self.function_returns_scalar { AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype)) } else { match self.collect_groups { @@ -429,7 +432,7 @@ impl PhysicalExpr for ApplyExpr { } } fn is_scalar(&self) -> bool { - self.returns_scalar + self.function_returns_scalar || self.function_operates_on_scalar } } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 7a7ec2c98537e..c7e89132bc2e0 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -15,7 +15,7 @@ pub struct BinaryExpr { expr: Expr, has_literal: bool, allow_threading: bool, - is_scalar: bool + is_scalar: bool, } impl BinaryExpr { @@ -26,7 +26,7 @@ impl BinaryExpr { expr: Expr, has_literal: bool, allow_threading: bool, - is_scalar: bool + is_scalar: bool, ) -> Self { Self { left, @@ -35,7 +35,7 @@ impl BinaryExpr { expr, has_literal, allow_threading, - is_scalar + is_scalar, } } } diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index 964c0784815fe..c776e4b951dd6 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -12,7 +12,7 @@ pub struct TernaryExpr { expr: Expr, // Can be expensive on small data to run literals in parallel. run_par: bool, - returns_scalar: bool + returns_scalar: bool, } impl TernaryExpr { @@ -22,7 +22,7 @@ impl TernaryExpr { falsy: Arc, expr: Expr, run_par: bool, - returns_scalar: bool + returns_scalar: bool, ) -> Self { Self { predicate, @@ -30,7 +30,7 @@ impl TernaryExpr { falsy, expr, run_par, - returns_scalar + returns_scalar, } } } diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 29f44153a84d4..e578b8da96795 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -303,7 +303,7 @@ fn create_physical_expr_inner( node_to_expr(expression, expr_arena), state.local.has_lit, state.allow_threading, - is_scalar + is_scalar, ))) }, Column(column) => Ok(Arc::new(ColumnExpr::new( @@ -464,7 +464,7 @@ fn create_physical_expr_inner( falsy, node_to_expr(expression, expr_arena), lit_count < 2, - is_scalar + is_scalar, ))) }, AnonymousFunction { @@ -505,7 +505,7 @@ fn create_physical_expr_inner( state.allow_threading, schema.cloned(), output_dtype, - is_scalar + is_scalar, ))) }, Function { @@ -545,7 +545,7 @@ fn create_physical_expr_inner( state.allow_threading, schema.cloned(), output_dtype, - is_scalar + is_scalar, ))) }, Slice { diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs index cd41e9260c922..25c59fad895be 100644 --- a/crates/polars-mem-engine/src/executors/projection_utils.rs +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -299,7 +299,7 @@ pub(super) fn check_expand_literals( } else { polars_ensure!(phys.is_scalar(), InvalidOperation: "Series length {} doesn't match the DataFrame height of {}\n\n\ - If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()'.", + If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').", series.len(), df_height ); series.new_from_index(0, df_height) diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index b13f06a3e5234..4be23e79df147 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -1,9 +1,9 @@ #[cfg(feature = "cse")] mod hash; +mod scalar; mod schema; -mod utils; mod traverse; -mod scalar; +mod utils; use std::hash::{Hash, Hasher}; @@ -13,16 +13,16 @@ use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_time_units, try_get_supertype}; use polars_utils::arena::{Arena, Node}; +pub use scalar::is_scalar_ae; #[cfg(feature = "ir_serde")] use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; +pub use traverse::*; pub use utils::*; use crate::constants::LEN; use crate::plans::Context; use crate::prelude::*; -pub use scalar::is_scalar_ae; -pub use traverse::*; #[derive(Clone, Debug, IntoStaticStr)] #[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] diff --git a/crates/polars-plan/src/plans/aexpr/scalar.rs b/crates/polars-plan/src/plans/aexpr/scalar.rs index 4076f05e049d3..f7d681b407d44 100644 --- a/crates/polars-plan/src/plans/aexpr/scalar.rs +++ b/crates/polars-plan/src/plans/aexpr/scalar.rs @@ -1,26 +1,27 @@ use recursive::recursive; -use super::*; +use super::*; #[recursive] pub fn is_scalar_ae(node: Node, expr_arena: &Arena) -> bool { match expr_arena.get(node) { AExpr::Literal(lv) => lv.is_scalar(), - AExpr::Function {options, input, .. } | AExpr::AnonymousFunction{options , input, ..} => { + AExpr::Function { options, input, .. } + | AExpr::AnonymousFunction { options, input, .. } => { if options.is_elementwise() { input.iter().all(|e| e.is_scalar(expr_arena)) } else { options.flags.contains(FunctionFlags::RETURNS_SCALAR) } }, - AExpr::BinaryExpr {left, right, ..} => { + AExpr::BinaryExpr { left, right, .. } => { is_scalar_ae(*left, expr_arena) && is_scalar_ae(*right, expr_arena) }, - AExpr::Ternary {truthy, falsy, ..} => { + AExpr::Ternary { truthy, falsy, .. } => { is_scalar_ae(*truthy, expr_arena) && is_scalar_ae(*falsy, expr_arena) }, AExpr::Agg(_) | AExpr::Len => true, - AExpr::Cast{expr , ..} | AExpr::Alias(expr, _) => is_scalar_ae(*expr, expr_arena), - _ => false + AExpr::Cast { expr, .. } | AExpr::Alias(expr, _) => is_scalar_ae(*expr, expr_arena), + _ => false, } } diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 2fb08aa34e1a8..29999ef6995ff 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -1,7 +1,6 @@ use super::*; impl AExpr { - /// Push nodes at this level to a pre-allocated stack. pub(crate) fn nodes(&self, container: &mut C) { use AExpr::*; @@ -51,12 +50,12 @@ impl AExpr { AnonymousFunction { input, .. } | Function { input, .. } => // we iterate in reverse order, so that the lhs is popped first and will be found // as the root columns/ input columns by `_suffix` and `_keep_name` etc. - { - input - .iter() - .rev() - .for_each(|e| container.push_node(e.node())) - }, + { + input + .iter() + .rev() + .for_each(|e| container.push_node(e.node())) + }, Explode(e) => container.push_node(*e), Window { function, diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index d345b385a4255..aef7cd1573347 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -68,4 +68,4 @@ pub fn all_streamable(exprs: &[ExprIR], expr_arena: &Arena, context: Cont exprs .iter() .all(|e| is_streamable(e.node(), expr_arena, context)) -} \ No newline at end of file +} diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index de94d6b263724..c44fc3fe81473 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -229,7 +229,7 @@ impl LiteralValue { } pub fn is_scalar(&self) -> bool { - !matches!(self, LiteralValue::Series(_) | LiteralValue::Range {..}) + !matches!(self, LiteralValue::Series(_) | LiteralValue::Range { .. }) } }