Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia committed Sep 29, 2023
1 parent 8a73718 commit 31044d2
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,37 @@ async fn test_udaf_returning_struct_subquery() {
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
}

#[tokio::test]
async fn test_udaf_shadows_builtin_fn() {
let TestContext {
mut ctx,
test_state,
} = TestContext::new();
let sql = "SELECT sum(arrow_cast(time, 'Int64')) from t";

// compute with builtin `sum` aggregator
let expected = [
"+-------------+",
"| SUM(t.time) |",
"+-------------+",
"| 19000 |",
"+-------------+",
];
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());

// Register `TimeSum` with name `sum`. This will shadow the builtin one
let sql = "SELECT sum(time) from t";
TimeSum::register(&mut ctx, test_state.clone(), "sum");
let expected = [
"+----------------------------+",
"| sum(t.time) |",
"+----------------------------+",
"| 1970-01-01T00:00:00.000019 |",
"+----------------------------+",
];
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
}

async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}
Expand Down Expand Up @@ -214,7 +245,7 @@ impl TestContext {
// Tell DataFusion about the "first" function
FirstSelector::register(&mut ctx);
// Tell DataFusion about the "time_sum" function
TimeSum::register(&mut ctx, Arc::clone(&test_state));
TimeSum::register(&mut ctx, Arc::clone(&test_state), "time_sum");

Self { ctx, test_state }
}
Expand Down Expand Up @@ -281,7 +312,7 @@ impl TimeSum {
Self { sum: 0, test_state }
}

fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
fn register(ctx: &mut SessionContext, test_state: Arc<TestState>, name: &str) {
let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);

// Returns the same type as its input
Expand All @@ -301,8 +332,6 @@ impl TimeSum {
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));

let name = "time_sum";

let time_sum =
AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type);

Expand Down

0 comments on commit 31044d2

Please sign in to comment.