Skip to content

Commit

Permalink
feat:implement sql style 'ends_with' and 'instr' string function (#8862)
Browse files Browse the repository at this point in the history
* feat:implement sql style 'ends_with' and 'instr' string function

* Use Arrow comparison functions for ends_with and starts_with implementations and extend tests for instr function
  • Loading branch information
zy-kkk committed Jan 23, 2024
1 parent 31b9b48 commit 084fdfb
Show file tree
Hide file tree
Showing 12 changed files with 366 additions and 26 deletions.
40 changes: 40 additions & 0 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,26 @@ async fn test_fn_initcap() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_fn_instr() -> Result<()> {
let expr = instr(col("a"), lit("b"));

let expected = [
"+-------------------------+",
"| instr(test.a,Utf8(\"b\")) |",
"+-------------------------+",
"| 2 |",
"| 2 |",
"| 0 |",
"| 5 |",
"+-------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
#[cfg(feature = "unicode_expressions")]
async fn test_fn_left() -> Result<()> {
Expand Down Expand Up @@ -634,6 +654,26 @@ async fn test_fn_starts_with() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_fn_ends_with() -> Result<()> {
let expr = ends_with(col("a"), lit("DEF"));

let expected = [
"+-------------------------------+",
"| ends_with(test.a,Utf8(\"DEF\")) |",
"+-------------------------------+",
"| true |",
"| false |",
"| false |",
"| false |",
"+-------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
#[cfg(feature = "unicode_expressions")]
async fn test_fn_strpos() -> Result<()> {
Expand Down
36 changes: 25 additions & 11 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ pub enum BuiltinScalarFunction {
DateTrunc,
/// date_bin
DateBin,
/// ends_with
EndsWith,
/// initcap
InitCap,
/// InStr
InStr,
/// left
Left,
/// lpad
Expand Down Expand Up @@ -446,7 +450,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::DatePart => Volatility::Immutable,
BuiltinScalarFunction::DateTrunc => Volatility::Immutable,
BuiltinScalarFunction::DateBin => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,
BuiltinScalarFunction::InStr => Volatility::Immutable,
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Lower => Volatility::Immutable,
Expand Down Expand Up @@ -708,6 +714,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::InStr => {
utf8_to_int_type(&input_expr_types[0], "instr")
}
BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"),
BuiltinScalarFunction::Lower => {
utf8_to_str_type(&input_expr_types[0], "lower")
Expand Down Expand Up @@ -795,6 +804,7 @@ impl BuiltinScalarFunction {
true,
)))),
BuiltinScalarFunction::StartsWith => Ok(Boolean),
BuiltinScalarFunction::EndsWith => Ok(Boolean),
BuiltinScalarFunction::Strpos => {
utf8_to_int_type(&input_expr_types[0], "strpos")
}
Expand Down Expand Up @@ -1211,17 +1221,19 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => {
Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
)
}

BuiltinScalarFunction::EndsWith
| BuiltinScalarFunction::InStr
| BuiltinScalarFunction::Strpos
| BuiltinScalarFunction::StartsWith => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
),

BuiltinScalarFunction::Substr => Signature::one_of(
vec![
Expand Down Expand Up @@ -1473,7 +1485,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Concat => &["concat"],
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
BuiltinScalarFunction::Chr => &["chr"],
BuiltinScalarFunction::EndsWith => &["ends_with"],
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::InStr => &["instr"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input
scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex");
scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. encoding can be base64 or hex");
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Lower, lower, string, "convert the string to lower case");
scalar_expr!(
Expand Down Expand Up @@ -830,6 +831,7 @@ scalar_expr!(SHA512, sha512, string, "SHA-512 hash");
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`");
scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
scalar_expr!(Substr, substr, string position, "substring from the `position` to the end");
scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters");
Expand Down Expand Up @@ -1372,6 +1374,7 @@ mod test {
test_scalar_expr!(Gcd, gcd, arg_1, arg_2);
test_scalar_expr!(Lcm, lcm, arg_1, arg_2);
test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(InStr, instr, string, substring);
test_scalar_expr!(Left, left, string, count);
test_scalar_expr!(Lower, lower, string);
test_nary_scalar_expr!(Lpad, lpad, string, count);
Expand Down Expand Up @@ -1410,6 +1413,7 @@ mod test {
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value);
test_scalar_expr!(StartsWith, starts_with, string, characters);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
test_scalar_expr!(Substr, substr, string, position);
test_scalar_expr!(Substr, substring, string, position, count);
Expand Down
143 changes: 142 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,15 @@ pub fn create_physical_fun(
internal_err!("Unsupported data type {other:?} for function initcap")
}
}),
BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::instr::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::instr::<i64>)(args)
}
other => internal_err!("Unsupported data type {other:?} for function instr"),
}),
BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left");
Expand Down Expand Up @@ -779,6 +788,17 @@ pub fn create_physical_fun(
internal_err!("Unsupported data type {other:?} for function starts_with")
}
}),
BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::ends_with::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::ends_with::<i64>)(args)
}
other => {
internal_err!("Unsupported data type {other:?} for function ends_with")
}
}),
BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
Expand Down Expand Up @@ -1001,7 +1021,7 @@ mod tests {
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
Int32Array, StringArray, UInt64Array,
Int32Array, Int64Array, StringArray, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
Expand Down Expand Up @@ -1393,6 +1413,95 @@ mod tests {
Utf8,
StringArray
);
test_function!(
InStr,
&[lit("abc"), lit("b")],
Ok(Some(2)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("c")],
Ok(Some(3)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("d")],
Ok(Some(0)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("")],
Ok(Some(1)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("Helloworld"), lit("world")],
Ok(Some(6)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("Helloworld"), lit(ScalarValue::Utf8(None))],
Ok(None),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit(ScalarValue::Utf8(None)), lit("Hello")],
Ok(None),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[
lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))),
lit(ScalarValue::LargeUtf8(Some("world".to_string())))
],
Ok(Some(6)),
i64,
Int64,
Int64Array
);
test_function!(
InStr,
&[
lit(ScalarValue::LargeUtf8(None)),
lit(ScalarValue::LargeUtf8(Some("world".to_string())))
],
Ok(None),
i64,
Int64,
Int64Array
);
test_function!(
InStr,
&[
lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))),
lit(ScalarValue::LargeUtf8(None))
],
Ok(None),
i64,
Int64,
Int64Array
);
#[cfg(feature = "unicode_expressions")]
test_function!(
Left,
Expand Down Expand Up @@ -2511,6 +2620,38 @@ mod tests {
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit("alph"),],
Ok(Some(false)),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit("bet"),],
Ok(Some(true)),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit(ScalarValue::Utf8(None)), lit("alph"),],
Ok(None),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit(ScalarValue::Utf8(None)),],
Ok(None),
bool,
Boolean,
BooleanArray
);
#[cfg(feature = "unicode_expressions")]
test_function!(
Strpos,
Expand Down
Loading

0 comments on commit 084fdfb

Please sign in to comment.