Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: Consolidate UDF tests #7704

Merged
merged 6 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 0 additions & 81 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2330,87 +2330,6 @@ mod tests {
Ok(())
}

#[tokio::test]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to user defined tests

async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);

ctx.register_udf(create_udf(
"MY_FUNC",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
myfunc,
));

// doesn't work as it was registered with non lowercase
let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
.await
.unwrap_err();
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_func\'"));

// Can call it if you put quotes
let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;

let expected = [
"+--------------+",
"| MY_FUNC(t.i) |",
"+--------------+",
"| 1 |",
"+--------------+",
];
assert_batches_eq!(expected, &result);

Ok(())
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

// Note capitalization
let my_avg = create_udaf(
"MY_AVG",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

ctx.register_udaf(my_avg);

// doesn't work as it was registered as non lowercase
let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t")
.await
.unwrap_err();
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_avg\'"));

// Can call it if you put quotes
let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;

let expected = [
"+-------------+",
"| MY_AVG(t.i) |",
"+-------------+",
"| 1.0 |",
"+-------------+",
];
assert_batches_eq!(expected, &result);

Ok(())
}

#[tokio::test]
async fn query_csv_with_custom_partition_extension() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down
16 changes: 2 additions & 14 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ async fn test_array_cast_expressions() -> Result<()> {

#[tokio::test]
async fn test_random_expression() -> Result<()> {
let ctx = create_ctx();
let ctx = SessionContext::new();
let sql = "SELECT random() r1";
let actual = execute(&ctx, sql).await;
let r1 = actual[0][0].parse::<f64>().unwrap();
Expand All @@ -627,7 +627,7 @@ async fn test_random_expression() -> Result<()> {

#[tokio::test]
async fn test_uuid_expression() -> Result<()> {
let ctx = create_ctx();
let ctx = SessionContext::new();
let sql = "SELECT uuid()";
let actual = execute(&ctx, sql).await;
let uuid = actual[0][0].parse::<uuid::Uuid>().unwrap();
Expand Down Expand Up @@ -886,18 +886,6 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> {
Ok(())
}

#[tokio::test]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was randomly in the file for testing exprs when it is actually an (user defined) aggregate query 😕

async fn csv_query_avg_sqrt() -> Result<()> {
let ctx = create_ctx();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100";
let mut actual = execute(&ctx, sql).await;
actual.sort();
let expected = vec![vec!["0.6706002946036462"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn nested_subquery() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
55 changes: 1 addition & 54 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use chrono::prelude::*;
use chrono::Duration;

use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan};
use datafusion::physical_plan::metrics::MetricValue;
use datafusion::physical_plan::ExecutionPlan;
Expand All @@ -34,15 +35,9 @@ use datafusion::prelude::*;
use datafusion::test_util;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion::{datasource::MemTable, physical_plan::collect};
use datafusion::{
error::{DataFusionError, Result},
physical_plan::ColumnarValue,
};
use datafusion::{execution::context::SessionContext, physical_plan::displayable};
use datafusion_common::cast::as_float64_array;
use datafusion_common::plan_err;
use datafusion_common::{assert_contains, assert_not_contains};
use datafusion_expr::Volatility;
use object_store::path::Path;
use std::fs::File;
use std::io::Write;
Expand Down Expand Up @@ -101,54 +96,6 @@ pub mod select;
mod sql_api;
pub mod subqueries;
pub mod timestamp;
pub mod udf;

fn assert_float_eq<T>(expected: &[Vec<T>], received: &[Vec<String>])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

where
T: AsRef<str>,
{
expected
.iter()
.flatten()
.zip(received.iter().flatten())
.for_each(|(l, r)| {
let (l, r) = (
l.as_ref().parse::<f64>().unwrap(),
r.as_str().parse::<f64>().unwrap(),
);
if l.is_nan() || r.is_nan() {
assert!(l.is_nan() && r.is_nan());
} else if (l - r).abs() > 2.0 * f64::EPSILON {
panic!("{l} != {r}")
}
});
}

fn create_ctx() -> SessionContext {
let ctx = SessionContext::new();

// register a custom UDF
ctx.register_udf(create_udf(
"custom_sqrt",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(custom_sqrt),
));

ctx
}

fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let arg = &args[0];
if let ColumnarValue::Array(v) = arg {
let input = as_float64_array(v).expect("cast failed");
let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect();
Ok(ColumnarValue::Array(Arc::new(array)))
} else {
unimplemented!()
}
}

fn create_join_context(
column_left: &str,
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/user_defined/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

/// Tests for user defined Scalar functions
mod user_defined_scalar_functions;

/// Tests for User Defined Aggregate Functions
mod user_defined_aggregates;

Expand Down
94 changes: 94 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
//! user defined aggregate functions

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::Int32Array;
use arrow_schema::Schema;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use datafusion::datasource::MemTable;
use datafusion::{
arrow::{
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
Expand All @@ -43,6 +46,8 @@ use datafusion::{
use datafusion_common::{
assert_contains, cast::as_primitive_array, exec_err, DataFusionError,
};
use datafusion_expr::create_udaf;
use datafusion_physical_expr::expressions::AvgAccumulator;

/// Test to show the contents of the setup
#[tokio::test]
Expand Down Expand Up @@ -204,6 +209,95 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

let batch1 = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)?;
let batch2 = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![4, 5]))],
)?;

let ctx = SessionContext::new();

let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;

// define a udaf, using a DataFusion's accumulator
let my_avg = create_udaf(
"my_avg",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

ctx.register_udaf(my_avg);

let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?;

let expected = [
"+-------------+",
"| my_avg(t.a) |",
"+-------------+",
"| 3.0 |",
"+-------------+",
];
assert_batches_eq!(expected, &result);

Ok(())
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let ctx = SessionContext::new();
let arr = Int32Array::from(vec![1]);
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

// Note capitalization
let my_avg = create_udaf(
"MY_AVG",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

ctx.register_udaf(my_avg);

// doesn't work as it was registered as non lowercase
let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err();
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_avg\'"));

// Can call it if you put quotes
let result = ctx
.sql("SELECT \"MY_AVG\"(i) FROM t")
.await?
.collect()
.await?;

let expected = [
"+-------------+",
"| MY_AVG(t.i) |",
"+-------------+",
"| 1.0 |",
"+-------------+",
];
assert_batches_eq!(expected, &result);

Ok(())
}

/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down
Loading
Loading