Skip to content

Commit

Permalink
Merge commit '2f550032140d42d1ee6d8ed86f7790766fa7302e' into chunchun…
Browse files Browse the repository at this point in the history
…/update-df-apr-week-1
  • Loading branch information
appletreeisyellow committed Apr 17, 2024
2 parents 38dd2d1 + 2f55003 commit 1a1979a
Show file tree
Hide file tree
Showing 7 changed files with 945 additions and 823 deletions.
82 changes: 68 additions & 14 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,20 @@ impl<T> Transformed<T> {
}
}

/// Transformation helper to process tree nodes that are siblings.
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
pub trait TransformedIterator: Iterator {
/// Applies `f` to each item in this iterator.
///
/// Visits all items in the iterator unless
/// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
///
/// # Returns
/// Error if `f` returns an error
///
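/// Ok(Transformed) such that:
/// 1. `transformed` is true if any return from `f` had transformed true
/// 2. `data` is a `Vec` holding, in iteration order, the `data` returned by `f` for
///    each visited item, followed by the remaining items unchanged once `f` has
///    returned `TreeNodeRecursion::Stop`
/// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty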
fn map_until_stop_and_collect<
F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
>(
Expand All @@ -551,22 +563,64 @@ impl<I: Iterator> TransformedIterator for I {
) -> Result<Transformed<Vec<Self::Item>>> {
let mut tnr = TreeNodeRecursion::Continue;
let mut transformed = false;
let data = self
.map(|item| match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
f(item).map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})
}
TreeNodeRecursion::Stop => Ok(item),
})
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::new(data, transformed, tnr))
self.map(|item| match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
f(item).map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})
}
TreeNodeRecursion::Stop => Ok(item),
})
.collect::<Result<Vec<_>>>()
.map(|data| Transformed::new(data, transformed, tnr))
}
}

/// Transformation helper for a fixed, heterogeneous group of sibling values.
///
/// This is the tuple-shaped counterpart of
/// [TransformedIterator::map_until_stop_and_collect]. It takes an initial
/// transformation result (`F0`) followed by any number of `EXPR, F` pairs, where
/// `EXPR` is the untransformed value and `F` is the expression that transforms it.
///
/// `F0` is always evaluated. For each following pair, `F` is evaluated only while the
/// most recent `Transformed.tnr` is `Continue` or `Jump`; once it is `Stop`, the
/// remaining pairs contribute their plain `EXPR` value and their `F` is not evaluated.
///
/// # Returns
/// The first error produced by `F0` or by any evaluated `F`.
///
/// Otherwise `Ok(Transformed<(data0, ..., dataN)>)` such that:
/// 1. `transformed` is true if any evaluated transformation reported `transformed`
/// 2. `data0` is the `Transformed.data` of `F0`, and each `data1` ... `dataN` is either
///    the `Transformed.data` of its `F` or the value of its `EXPR`
/// 3. `tnr` is taken from the last transformation that was evaluated (`F0` if none of
///    the `F`s ran)
#[macro_export]
macro_rules! map_until_stop_and_collect {
    ($F0:expr, $($EXPR:expr, $F:expr),*) => {{
        $F0.and_then(
            |Transformed {
                 data: first,
                 transformed: mut any_transformed,
                 tnr: mut recursion,
             }| {
                let collected = (
                    first,
                    $(
                        match recursion {
                            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
                                // Still traversing: run this sibling's transformation
                                // and carry its flags forward to the next pair.
                                let step = $F?;
                                recursion = step.tnr;
                                any_transformed |= step.transformed;
                                step.data
                            }
                            // Traversal was stopped earlier: keep the sibling untouched.
                            TreeNodeRecursion::Stop => $EXPR,
                        },
                    )*
                );
                Ok(Transformed::new(collected, any_transformed, recursion))
            },
        )
    }}
}

/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
pub trait TransformedResult<T> {
fn data(self) -> Result<T>;
Expand Down
167 changes: 57 additions & 110 deletions datafusion/core/benches/sql_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ use arrow::datatypes::{DataType, Field, Fields, Schema};
use datafusion::datasource::MemTable;
use datafusion::execution::context::SessionContext;
use std::sync::Arc;
use test_utils::tpcds::tpcds_schemas;
use test_utils::tpch::tpch_schemas;
use test_utils::TableDef;
use tokio::runtime::Runtime;

/// Create a logical plan from the specified sql
Expand All @@ -48,116 +51,18 @@ fn physical_plan(ctx: &SessionContext, sql: &str) {
}

/// Create schema with the specified number of columns
pub fn create_schema(column_prefix: &str, num_columns: usize) -> Schema {
fn create_schema(column_prefix: &str, num_columns: usize) -> Schema {
let fields: Fields = (0..num_columns)
.map(|i| Field::new(format!("{column_prefix}{i}"), DataType::Int32, true))
.collect();
Schema::new(fields)
}

pub fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc<MemTable> {
/// Create a `MemTable` whose schema comes from
/// `create_schema(column_prefix, num_columns)` and which is constructed with an empty
/// `vec![]` of data, wrapped in an `Arc`.
///
/// Panics if `MemTable::try_new` returns an error.
fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc<MemTable> {
    let schema = create_schema(column_prefix, num_columns);
    let table = MemTable::try_new(Arc::new(schema), vec![]).unwrap();
    Arc::new(table)
}

pub fn create_tpch_schemas() -> [(String, Schema); 8] {
let lineitem_schema = Schema::new(vec![
Field::new("l_orderkey", DataType::Int64, false),
Field::new("l_partkey", DataType::Int64, false),
Field::new("l_suppkey", DataType::Int64, false),
Field::new("l_linenumber", DataType::Int32, false),
Field::new("l_quantity", DataType::Decimal128(15, 2), false),
Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
Field::new("l_discount", DataType::Decimal128(15, 2), false),
Field::new("l_tax", DataType::Decimal128(15, 2), false),
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Date32, false),
Field::new("l_commitdate", DataType::Date32, false),
Field::new("l_receiptdate", DataType::Date32, false),
Field::new("l_shipinstruct", DataType::Utf8, false),
Field::new("l_shipmode", DataType::Utf8, false),
Field::new("l_comment", DataType::Utf8, false),
]);

let orders_schema = Schema::new(vec![
Field::new("o_orderkey", DataType::Int64, false),
Field::new("o_custkey", DataType::Int64, false),
Field::new("o_orderstatus", DataType::Utf8, false),
Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
Field::new("o_orderdate", DataType::Date32, false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
Field::new("o_shippriority", DataType::Int32, false),
Field::new("o_comment", DataType::Utf8, false),
]);

let part_schema = Schema::new(vec![
Field::new("p_partkey", DataType::Int64, false),
Field::new("p_name", DataType::Utf8, false),
Field::new("p_mfgr", DataType::Utf8, false),
Field::new("p_brand", DataType::Utf8, false),
Field::new("p_type", DataType::Utf8, false),
Field::new("p_size", DataType::Int32, false),
Field::new("p_container", DataType::Utf8, false),
Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
Field::new("p_comment", DataType::Utf8, false),
]);

let supplier_schema = Schema::new(vec![
Field::new("s_suppkey", DataType::Int64, false),
Field::new("s_name", DataType::Utf8, false),
Field::new("s_address", DataType::Utf8, false),
Field::new("s_nationkey", DataType::Int64, false),
Field::new("s_phone", DataType::Utf8, false),
Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
Field::new("s_comment", DataType::Utf8, false),
]);

let partsupp_schema = Schema::new(vec![
Field::new("ps_partkey", DataType::Int64, false),
Field::new("ps_suppkey", DataType::Int64, false),
Field::new("ps_availqty", DataType::Int32, false),
Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
Field::new("ps_comment", DataType::Utf8, false),
]);

let customer_schema = Schema::new(vec![
Field::new("c_custkey", DataType::Int64, false),
Field::new("c_name", DataType::Utf8, false),
Field::new("c_address", DataType::Utf8, false),
Field::new("c_nationkey", DataType::Int64, false),
Field::new("c_phone", DataType::Utf8, false),
Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
Field::new("c_mktsegment", DataType::Utf8, false),
Field::new("c_comment", DataType::Utf8, false),
]);

let nation_schema = Schema::new(vec![
Field::new("n_nationkey", DataType::Int64, false),
Field::new("n_name", DataType::Utf8, false),
Field::new("n_regionkey", DataType::Int64, false),
Field::new("n_comment", DataType::Utf8, false),
]);

let region_schema = Schema::new(vec![
Field::new("r_regionkey", DataType::Int64, false),
Field::new("r_name", DataType::Utf8, false),
Field::new("r_comment", DataType::Utf8, false),
]);

[
("lineitem".to_string(), lineitem_schema),
("orders".to_string(), orders_schema),
("part".to_string(), part_schema),
("supplier".to_string(), supplier_schema),
("partsupp".to_string(), partsupp_schema),
("customer".to_string(), customer_schema),
("nation".to_string(), nation_schema),
("region".to_string(), region_schema),
]
}

fn create_context() -> SessionContext {
let ctx = SessionContext::new();
ctx.register_table("t1", create_table_provider("a", 200))
Expand All @@ -168,16 +73,19 @@ fn create_context() -> SessionContext {
.unwrap();
ctx.register_table("t1000", create_table_provider("d", 1000))
.unwrap();
ctx
}

let tpch_schemas = create_tpch_schemas();
tpch_schemas.iter().for_each(|(name, schema)| {
/// Register every table definition in `defs` with `ctx` and hand the context back.
///
/// Each definition is registered under its `name` as a `MemTable` built from a clone
/// of its `schema` and an empty `vec![]` of data.
///
/// Panics if creating a `MemTable` or registering it with the context fails.
fn register_defs(ctx: SessionContext, defs: Vec<TableDef>) -> SessionContext {
    for TableDef { name, schema } in &defs {
        let table = MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap();
        ctx.register_table(name, Arc::new(table)).unwrap();
    }

    ctx
}

Expand Down Expand Up @@ -236,40 +144,79 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});

// --- TPC-H ---

let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas());

let tpch_queries = [
"q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
"q14", // "q15", q15 has multiple SQL statements which is not supported
"q16", "q17", "q18", "q19", "q20", "q21", "q22",
];

for q in tpch_queries {
let sql = std::fs::read_to_string(format!("../../benchmarks/queries/{}.sql", q))
.unwrap();
let sql =
std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap();
c.bench_function(&format!("physical_plan_tpch_{}", q), |b| {
b.iter(|| physical_plan(&ctx, &sql))
b.iter(|| physical_plan(&tpch_ctx, &sql))
});
}

let all_tpch_sql_queries = tpch_queries
.iter()
.map(|q| {
std::fs::read_to_string(format!("../../benchmarks/queries/{}.sql", q))
.unwrap()
std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap()
})
.collect::<Vec<_>>();

c.bench_function("physical_plan_tpch_all", |b| {
b.iter(|| {
for sql in &all_tpch_sql_queries {
physical_plan(&ctx, sql)
physical_plan(&tpch_ctx, sql)
}
})
});

c.bench_function("logical_plan_tpch_all", |b| {
b.iter(|| {
for sql in &all_tpch_sql_queries {
logical_plan(&ctx, sql)
logical_plan(&tpch_ctx, sql)
}
})
});

// --- TPC-DS ---

let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas());

// 10, 35: Physical plan does not support logical expression Exists(<subquery>)
// 45: Physical plan does not support logical expression (<subquery>)
// 41: Optimizing disjunctions not supported
let ignored = [10, 35, 41, 45];

let raw_tpcds_sql_queries = (1..100)
.filter(|q| !ignored.contains(q))
.map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap())
.collect::<Vec<_>>();

// some queries have multiple statements
let all_tpcds_sql_queries = raw_tpcds_sql_queries
.iter()
.flat_map(|sql| sql.split(';').filter(|s| !s.trim().is_empty()))
.collect::<Vec<_>>();

c.bench_function("physical_plan_tpcds_all", |b| {
b.iter(|| {
for sql in &all_tpcds_sql_queries {
physical_plan(&tpcds_ctx, sql)
}
})
});

c.bench_function("logical_plan_tpcds_all", |b| {
b.iter(|| {
for sql in &all_tpcds_sql_queries {
logical_plan(&tpcds_ctx, sql)
}
})
});
Expand Down
Loading

0 comments on commit 1a1979a

Please sign in to comment.