diff --git a/Cargo.lock b/Cargo.lock index 77ab07d3bf8c..b16b59e67ba2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3031,6 +3031,7 @@ dependencies = [ "async-backtrace", "async-trait-fn", "chrono", + "common-ast", "common-base", "common-building", "common-catalog", diff --git a/src/query/ast/src/ast/expr.rs b/src/query/ast/src/ast/expr.rs index c52c8b7b1c9d..5675c586f2fa 100644 --- a/src/query/ast/src/ast/expr.rs +++ b/src/query/ast/src/ast/expr.rs @@ -530,6 +530,21 @@ impl Expr { | Expr::DateTrunc { span, .. } => *span, } } + + pub fn all_function_like_syntaxes() -> &'static [&'static str] { + &[ + "CAST", + "TRY_CAST", + "EXTRACT", + "DATE_PART", + "POSITION", + "SUBSTRING", + "TRIM", + "DATE_ADD", + "DATE_SUB", + "DATE_TRUNC", + ] + } } impl Display for IntervalKind { diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 6867c87a51cd..7c9291c9c437 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -673,7 +673,7 @@ impl<'a> TypeChecker<'a> { let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string(); let func_name = func_name.as_str(); if !is_builtin_function(func_name) - && !Self::all_rewritable_scalar_function().contains(&func_name) + && !Self::all_sugar_functions().contains(&func_name) { if let Some(udf) = self.resolve_udf(*span, func_name, args).await? { return Ok(udf); @@ -686,7 +686,7 @@ impl<'a> TypeChecker<'a> { .chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string)) .chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string)) .chain( - Self::all_rewritable_scalar_function() + Self::all_sugar_functions() .iter() .cloned() .map(str::to_string), @@ -1742,7 +1742,7 @@ impl<'a> TypeChecker<'a> { ) -> Result> { // Check if current function is a virtual function, e.g. `database`, `version` if let Some(rewritten_func_result) = self - .try_rewrite_scalar_function(span, func_name, arguments) + .try_rewrite_sugar_function(span, func_name, arguments) .await { return rewritten_func_result; @@ -2133,7 +2133,7 @@ impl<'a> TypeChecker<'a> { Ok(Box::new((subquery_expr.into(), data_type))) } - pub fn all_rewritable_scalar_function() -> &'static [&'static str] { + pub fn all_sugar_functions() -> &'static [&'static str] { &[ "database", "currentdatabase", @@ -2162,7 +2162,7 @@ impl<'a> TypeChecker<'a> { #[async_recursion::async_recursion] #[async_backtrace::framed] - async fn try_rewrite_scalar_function( + async fn try_rewrite_sugar_function( &mut self, span: Span, func_name: &str, diff --git a/src/query/storages/system/Cargo.toml b/src/query/storages/system/Cargo.toml index 0a9300c2b974..bbdb4bd28f0b 100644 --- a/src/query/storages/system/Cargo.toml +++ b/src/query/storages/system/Cargo.toml @@ -15,6 +15,7 @@ test = false enable-histogram-metrics = ["common-metrics/enable-histogram"] [dependencies] +common-ast = { path = "../../ast" } common-base = { path = "../../../common/base" } common-catalog = { path = "../../catalog" } common-config = { path = "../../config" } diff --git a/src/query/storages/system/src/functions_table.rs b/src/query/storages/system/src/functions_table.rs index d69b531cfa6e..6bf0b0dae035 100644 --- a/src/query/storages/system/src/functions_table.rs +++ b/src/query/storages/system/src/functions_table.rs @@ -31,6 +31,7 @@ use common_meta_app::principal::UserDefinedFunction; use common_meta_app::schema::TableIdent; use common_meta_app::schema::TableInfo; use common_meta_app::schema::TableMeta; +use common_sql::TypeChecker; use common_users::UserApiProvider; use crate::table::AsyncOneBlockSystemTable; @@ -54,27 +55,37 @@ impl AsyncSystemTable for FunctionsTable { ctx: Arc, _push_downs: Option, ) -> Result { - // TODO(andylokandy): add rewritable function names, e.g. database() - let func_names = BUILTIN_FUNCTIONS.registered_names(); + let mut scalar_func_names: Vec = BUILTIN_FUNCTIONS.registered_names(); + scalar_func_names.extend( + TypeChecker::all_sugar_functions() + .iter() + .map(|name| name.to_string()), + ); + scalar_func_names.extend( + common_ast::ast::Expr::all_function_like_syntaxes() + .iter() + .map(|name| name.to_lowercase()), + ); + scalar_func_names.sort(); let aggregate_function_factory = AggregateFunctionFactory::instance(); let aggr_func_names = aggregate_function_factory.registered_names(); let udfs = FunctionsTable::get_udfs(ctx).await?; - let names: Vec<&str> = func_names + let names: Vec<&str> = scalar_func_names .iter() - .chain(aggr_func_names.iter()) + .chain(&aggr_func_names) .chain(udfs.iter().map(|udf| &udf.name)) .map(|x| x.as_str()) .collect(); - let builtin_func_len = func_names.len() + aggr_func_names.len(); + let builtin_func_len = scalar_func_names.len() + aggr_func_names.len(); let is_builtin = (0..names.len()) .map(|i| i < builtin_func_len) .collect::>(); let is_aggregate = (0..names.len()) - .map(|i| i >= func_names.len() && i < builtin_func_len) + .map(|i| i >= scalar_func_names.len() && i < builtin_func_len) .collect::>(); let definitions = (0..names.len())