Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 10, 2024
1 parent db5a9db commit ed3c486
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 38 deletions.
19 changes: 11 additions & 8 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pub struct ApplyExpr {
function: SpecialEq<Arc<dyn SeriesUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
returns_scalar: bool,
function_returns_scalar: bool,
function_operates_on_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
Expand All @@ -29,6 +30,7 @@ pub struct ApplyExpr {
}

impl ApplyExpr {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
Expand All @@ -37,22 +39,22 @@ impl ApplyExpr {
allow_threading: bool,
input_schema: Option<SchemaRef>,
output_dtype: Option<DataType>,
returns_scalar: bool
returns_scalar: bool,
) -> Self {
#[cfg(debug_assertions)]
if matches!(options.collect_groups, ApplyOptions::ElementWise)
&& options.flags.contains(FunctionFlags::RETURNS_SCALAR)
{
panic!("expr {:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr)
}
dbg!(returns_scalar, options.flags.contains(FunctionFlags::RETURNS_SCALAR), &expr);

Self {
inputs,
function,
expr,
collect_groups: options.collect_groups,
returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
function_returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
function_operates_on_scalar: returns_scalar,
allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME),
pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY),
input_schema,
Expand All @@ -74,7 +76,8 @@ impl ApplyExpr {
function,
expr,
collect_groups,
returns_scalar: false,
function_returns_scalar: false,
function_operates_on_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
Expand Down Expand Up @@ -106,7 +109,7 @@ impl ApplyExpr {
ca: ListChunked,
) -> PolarsResult<AggregationContext<'a>> {
let all_unit_len = all_unit_length(&ca);
if all_unit_len && self.returns_scalar {
if all_unit_len && self.function_returns_scalar {
ac.with_agg_state(AggState::AggregatedScalar(
ca.explode().unwrap().into_series(),
));
Expand Down Expand Up @@ -255,7 +258,7 @@ impl ApplyExpr {
let mut ac = acs.swap_remove(0);
ac.with_update_groups(UpdateGroups::No);

let agg_state = if self.returns_scalar {
let agg_state = if self.function_returns_scalar {
AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype))
} else {
match self.collect_groups {
Expand Down Expand Up @@ -429,7 +432,7 @@ impl PhysicalExpr for ApplyExpr {
}
}
fn is_scalar(&self) -> bool {
self.returns_scalar
self.function_returns_scalar || self.function_operates_on_scalar
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct BinaryExpr {
expr: Expr,
has_literal: bool,
allow_threading: bool,
is_scalar: bool
is_scalar: bool,
}

impl BinaryExpr {
Expand All @@ -26,7 +26,7 @@ impl BinaryExpr {
expr: Expr,
has_literal: bool,
allow_threading: bool,
is_scalar: bool
is_scalar: bool,
) -> Self {
Self {
left,
Expand All @@ -35,7 +35,7 @@ impl BinaryExpr {
expr,
has_literal,
allow_threading,
is_scalar
is_scalar,
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-expr/src/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct TernaryExpr {
expr: Expr,
// Can be expensive on small data to run literals in parallel.
run_par: bool,
returns_scalar: bool
returns_scalar: bool,
}

impl TernaryExpr {
Expand All @@ -22,15 +22,15 @@ impl TernaryExpr {
falsy: Arc<dyn PhysicalExpr>,
expr: Expr,
run_par: bool,
returns_scalar: bool
returns_scalar: bool,
) -> Self {
Self {
predicate,
truthy,
falsy,
expr,
run_par,
returns_scalar
returns_scalar,
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
state.local.has_lit,
state.allow_threading,
is_scalar
is_scalar,
)))
},
Column(column) => Ok(Arc::new(ColumnExpr::new(
Expand Down Expand Up @@ -464,7 +464,7 @@ fn create_physical_expr_inner(
falsy,
node_to_expr(expression, expr_arena),
lit_count < 2,
is_scalar
is_scalar,
)))
},
AnonymousFunction {
Expand Down Expand Up @@ -505,7 +505,7 @@ fn create_physical_expr_inner(
state.allow_threading,
schema.cloned(),
output_dtype,
is_scalar
is_scalar,
)))
},
Function {
Expand Down Expand Up @@ -545,7 +545,7 @@ fn create_physical_expr_inner(
state.allow_threading,
schema.cloned(),
output_dtype,
is_scalar
is_scalar,
)))
},
Slice {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-mem-engine/src/executors/projection_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ pub(super) fn check_expand_literals(
} else {
polars_ensure!(phys.is_scalar(),
InvalidOperation: "Series length {} doesn't match the DataFrame height of {}\n\n\
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()'.",
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').",
series.len(), df_height
);
series.new_from_index(0, df_height)
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-plan/src/plans/aexpr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(feature = "cse")]
mod hash;
mod scalar;
mod schema;
mod utils;
mod traverse;
mod scalar;
mod utils;

use std::hash::{Hash, Hasher};

Expand All @@ -13,16 +13,16 @@ use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::utils::{get_time_units, try_get_supertype};
use polars_utils::arena::{Arena, Node};
pub use scalar::is_scalar_ae;
#[cfg(feature = "ir_serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
pub use traverse::*;
pub use utils::*;

use crate::constants::LEN;
use crate::plans::Context;
use crate::prelude::*;
pub use scalar::is_scalar_ae;
pub use traverse::*;

#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
Expand Down
13 changes: 7 additions & 6 deletions crates/polars-plan/src/plans/aexpr/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
use recursive::recursive;
use super::*;

use super::*;

#[recursive]
pub fn is_scalar_ae(node: Node, expr_arena: &Arena<AExpr>) -> bool {
match expr_arena.get(node) {
AExpr::Literal(lv) => lv.is_scalar(),
AExpr::Function {options, input, .. } | AExpr::AnonymousFunction{options , input, ..} => {
AExpr::Function { options, input, .. }
| AExpr::AnonymousFunction { options, input, .. } => {
if options.is_elementwise() {
input.iter().all(|e| e.is_scalar(expr_arena))
} else {
options.flags.contains(FunctionFlags::RETURNS_SCALAR)
}
},
AExpr::BinaryExpr {left, right, ..} => {
AExpr::BinaryExpr { left, right, .. } => {
is_scalar_ae(*left, expr_arena) && is_scalar_ae(*right, expr_arena)
},
AExpr::Ternary {truthy, falsy, ..} => {
AExpr::Ternary { truthy, falsy, .. } => {
is_scalar_ae(*truthy, expr_arena) && is_scalar_ae(*falsy, expr_arena)
},
AExpr::Agg(_) | AExpr::Len => true,
AExpr::Cast{expr , ..} | AExpr::Alias(expr, _) => is_scalar_ae(*expr, expr_arena),
_ => false
AExpr::Cast { expr, .. } | AExpr::Alias(expr, _) => is_scalar_ae(*expr, expr_arena),
_ => false,
}
}
13 changes: 6 additions & 7 deletions crates/polars-plan/src/plans/aexpr/traverse.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::*;

impl AExpr {

/// Push nodes at this level to a pre-allocated stack.
pub(crate) fn nodes<C: PushNode>(&self, container: &mut C) {
use AExpr::*;
Expand Down Expand Up @@ -51,12 +50,12 @@ impl AExpr {
AnonymousFunction { input, .. } | Function { input, .. } =>
// we iterate in reverse order, so that the lhs is popped first and will be found
// as the root columns/ input columns by `_suffix` and `_keep_name` etc.
{
input
.iter()
.rev()
.for_each(|e| container.push_node(e.node()))
},
{
input
.iter()
.rev()
.for_each(|e| container.push_node(e.node()))
},
Explode(e) => container.push_node(*e),
Window {
function,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/aexpr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ pub fn all_streamable(exprs: &[ExprIR], expr_arena: &Arena<AExpr>, context: Cont
exprs
.iter()
.all(|e| is_streamable(e.node(), expr_arena, context))
}
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl LiteralValue {
}

pub fn is_scalar(&self) -> bool {
!matches!(self, LiteralValue::Series(_) | LiteralValue::Range {..})
!matches!(self, LiteralValue::Series(_) | LiteralValue::Range { .. })
}
}

Expand Down

0 comments on commit ed3c486

Please sign in to comment.