Skip to content

Commit

Permalink
squash
Browse files Browse the repository at this point in the history
  • Loading branch information
avantgardnerio committed Aug 6, 2023
1 parent e39b5ca commit 222c458
Show file tree
Hide file tree
Showing 14 changed files with 597 additions and 33 deletions.
103 changes: 103 additions & 0 deletions datafusion/core/src/physical_optimizer/limit_aggregation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! An optimizer rule that detects aggregate operations that could use a limited bucket count

use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::aggregates::AggregateExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::ExecutionPlan;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;

/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
///
/// The rule looks for a `SortExec` whose input is an `AggregateExec` and copies the
/// sort's `fetch` value into the aggregate's `limit` field. The struct itself carries
/// no state.
#[derive(Debug)]
pub struct LimitAggregation {}

impl LimitAggregation {
    /// Create a new `LimitAggregation`
    pub fn new() -> Self {
        Self {}
    }

    /// Walks `plan` bottom-up and, for every `SortExec` that has a `fetch` and sits
    /// directly on top of an `AggregateExec`, replaces the pair with a copy in which
    /// the aggregate's `limit` is set to the sort's `fetch`.
    ///
    /// Nodes that do not match are returned unchanged (the same `Arc`), so parents
    /// are only rebuilt when one of their children was actually rewritten.
    ///
    /// # Errors
    ///
    /// Returns an error if a `SortExec` does not report exactly one child, or if
    /// rebuilding a node (`AggregateExec::try_new` / `with_new_children`) fails.
    fn recurse(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
        // Rewrite the subtrees first so that a matching sort is found anywhere in
        // the plan, not only at the root.
        let old_children = plan.children();
        let new_children = old_children
            .iter()
            .map(|child| Self::recurse(Arc::clone(child)))
            .collect::<Result<Vec<_>>>()?;
        let any_child_changed = old_children
            .iter()
            .zip(new_children.iter())
            .any(|(old, new)| !Arc::ptr_eq(old, new));
        let plan = if any_child_changed {
            plan.with_new_children(new_children)?
        } else {
            plan
        };

        // Not a sort
        let sort = match plan.as_any().downcast_ref::<SortExec>() {
            Some(sort) => sort,
            None => return Ok(plan),
        };

        // A sort without a fetch needs the complete aggregate output: nothing to push down
        let limit = match sort.fetch() {
            Some(limit) => limit,
            None => return Ok(plan),
        };

        // Error if the sort does not have exactly one input
        let children = sort.children();
        let child = match children.as_slice() {
            [] => {
                return Err(DataFusionError::Execution(
                    "Sorts should have children".to_string(),
                ))
            }
            [child] => child,
            _ => {
                return Err(DataFusionError::Execution(
                    "Sorts should have 1 child".to_string(),
                ))
            }
        };

        // Sort doesn't have an aggregate before it
        let aggr = match child.as_any().downcast_ref::<AggregateExec>() {
            Some(aggr) => aggr,
            None => return Ok(plan),
        };

        // We found what we want: clone, copy the limit down, and return modified node.
        // NOTE(review): this does not check that the sort keys correspond to the
        // aggregate's output; whether the hint is honoured is decided when the
        // aggregate picks its stream implementation - confirm that is sufficient.
        let mut new_aggr = AggregateExec::try_new(
            aggr.mode,
            aggr.group_by.clone(),
            aggr.aggr_expr.clone(),
            aggr.filter_expr.clone(),
            aggr.order_by_expr.clone(),
            aggr.input.clone(),
            aggr.input_schema.clone(),
        )?;
        new_aggr.limit = Some(limit);

        // The replacement sort must keep the original fetch and partitioning
        // behaviour, otherwise the plan would no longer truncate its output.
        let new_sort = SortExec::new(sort.expr().to_vec(), Arc::new(new_aggr))
            .with_fetch(sort.fetch())
            .with_preserve_partitioning(sort.preserve_partitioning());
        Ok(Arc::new(new_sort))
    }
}

impl Default for LimitAggregation {
fn default() -> Self {
Self::new()
}
}

impl PhysicalOptimizerRule for LimitAggregation {
    /// Rewrites `plan` by handing it to [`LimitAggregation::recurse`].
    ///
    /// The configuration options are not consulted by this rule.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // `plan` is already owned here, so it can be moved instead of cloned
        LimitAggregation::recurse(plan)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "limit aggregation"
    }

    /// Always `true`.
    /// NOTE(review): presumably this asks the optimizer to verify the rule leaves
    /// the plan schema unchanged - confirm against the trait documentation.
    fn schema_check(&self) -> bool {
        true
    }
}
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod coalesce_batches;
pub mod combine_partial_final_agg;
pub mod dist_enforcement;
pub mod join_selection;
pub mod limit_aggregation;
pub mod optimizer;
pub mod pipeline_checker;
pub mod pruning;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/src/physical_optimizer/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
use crate::physical_optimizer::dist_enforcement::EnforceDistribution;
use crate::physical_optimizer::join_selection::JoinSelection;
use crate::physical_optimizer::limit_aggregation::LimitAggregation;
use crate::physical_optimizer::pipeline_checker::PipelineChecker;
use crate::physical_optimizer::repartition::Repartition;
use crate::physical_optimizer::sort_enforcement::EnforceSorting;
Expand Down Expand Up @@ -101,6 +102,7 @@ impl PhysicalOptimizer {
// diagnostic error message when this happens. It makes no changes to the
// given query plan; i.e. it only acts as a final gatekeeping rule.
Arc::new(PipelineChecker::new()),
Arc::new(LimitAggregation::new()),
];

Self::with_rules(rules)
Expand Down
55 changes: 43 additions & 12 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ use std::sync::Arc;
mod group_values;
mod no_grouping;
mod order;
mod priority_queue;
mod row_hash;

use crate::physical_plan::aggregates::priority_queue::GroupedPriorityQueueAggregateStream;
pub use datafusion_expr::AggregateFunction;
use datafusion_physical_expr::aggregate::is_order_sensitive;
pub use datafusion_physical_expr::expressions::create_aggregate_expr;
Expand Down Expand Up @@ -228,14 +230,16 @@ impl PartialEq for PhysicalGroupBy {

enum StreamType {
AggregateStream(AggregateStream),
GroupedHashAggregateStream(GroupedHashAggregateStream),
GroupedHash(GroupedHashAggregateStream),
GroupedPriorityQueue(GroupedPriorityQueueAggregateStream),
}

impl From<StreamType> for SendableRecordBatchStream {
fn from(stream: StreamType) -> Self {
match stream {
StreamType::AggregateStream(stream) => Box::pin(stream),
StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream),
StreamType::GroupedHash(stream) => Box::pin(stream),
StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
}
}
}
Expand Down Expand Up @@ -265,6 +269,8 @@ pub struct AggregateExec {
pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
/// (ORDER BY clause) expression for each aggregate expression
pub(crate) order_by_expr: Vec<Option<LexOrdering>>,
/// Set if the output of this aggregation is truncated by an upstream sort/limit clause
pub(crate) limit: Option<usize>,
/// Input plan, could be a partial aggregate or the input to the aggregate
pub(crate) input: Arc<dyn ExecutionPlan>,
/// Schema after the aggregate is applied
Expand Down Expand Up @@ -670,6 +676,7 @@ impl AggregateExec {
metrics: ExecutionPlanMetricsSet::new(),
aggregation_ordering,
required_input_ordering,
limit: None,
})
}

Expand Down Expand Up @@ -718,15 +725,29 @@ impl AggregateExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<StreamType> {
// no group by at all
if self.group_by.expr.is_empty() {
Ok(StreamType::AggregateStream(AggregateStream::new(
return Ok(StreamType::AggregateStream(AggregateStream::new(
self, context, partition,
)?))
} else {
Ok(StreamType::GroupedHashAggregateStream(
GroupedHashAggregateStream::new(self, context, partition)?,
))
)?));
}

// grouping by an expression that has a sort/limit upstream
let is_minmax =
GroupedPriorityQueueAggregateStream::get_minmax_desc(self).is_some();
if self.limit.is_some() && is_minmax {
println!("Using limited priority queue aggregation");
return Ok(StreamType::GroupedPriorityQueue(
GroupedPriorityQueueAggregateStream::new(
self, context, partition, self.limit,
)?,
));
}

// grouping by something else and we need to just materialize all results
Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
self, context, partition,
)?))
}
}

Expand Down Expand Up @@ -1149,7 +1170,7 @@ fn evaluate(
}

/// Evaluates expressions against a record batch.
fn evaluate_many(
pub fn evaluate_many(
expr: &[Vec<Arc<dyn PhysicalExpr>>],
batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
Expand All @@ -1172,7 +1193,17 @@ fn evaluate_optional(
.collect::<Result<Vec<_>>>()
}

fn evaluate_group_by(
/// Evaluate a group by expression against a `RecordBatch`
///
/// Arguments:
/// `group_by`: the expression to evaluate
/// `batch`: the `RecordBatch` to evaluate against
///
/// Returns: a `Vec` of `Vec`s of arrays holding the results
/// The outer `Vec` appears to have one entry per grouping set
/// The inner `Vec` contains the results per expression
/// The innermost array contains the results per row
pub fn evaluate_group_by(
group_by: &PhysicalGroupBy,
batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
Expand Down Expand Up @@ -1841,10 +1872,10 @@ mod tests {
assert!(matches!(stream, StreamType::AggregateStream(_)));
}
1 => {
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
assert!(matches!(stream, StreamType::GroupedHash(_)));
}
2 => {
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
assert!(matches!(stream, StreamType::GroupedHash(_)));
}
_ => panic!("Unknown version: {version}"),
}
Expand Down
Loading

0 comments on commit 222c458

Please sign in to comment.