From 562113b5c0175666887e7bfb722c441fcbc95d0f Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Tue, 17 Sep 2024 10:41:35 -0700 Subject: [PATCH 01/14] Implement SQL validation based on grammar element Signed-off-by: Tomoyuki Morita --- .../DenyListGrammarElementValidator.java | 19 + .../sql/spark/validator/GrammarElement.java | 87 ++++ .../validator/GrammarElementValidator.java | 10 + .../GrammarElementValidatorFactory.java | 74 +++ .../spark/validator/SQLQueryValidator.java | 491 ++++++++++++++++++ .../validator/SQLQueryValidatorTest.java | 301 +++++++++++ 6 files changed, 982 insertions(+) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java new file mode 100644 index 0000000000..514e2c8ad8 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import java.util.Set; +import lombok.RequiredArgsConstructor; + +@RequiredArgsConstructor +public class DenyListGrammarElementValidator implements GrammarElementValidator { + private final Set denyList; + + @Override + public boolean isValid(GrammarElement element) { + return !denyList.contains(element); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java new file mode 100644 index 0000000000..562a83dcd4 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +enum GrammarElement { + ALTER_NAMESPACE("ALTER DATABASE/TABLE/NAMESPACE"), + ALTER_VIEW("ALTER VIEW"), + CREATE_NAMESPACE("CREATE DATABASE/TABLE/NAMESPACE"), + CREATE_FUNCTION("CREATE FUNCTION"), + CREATE_VIEW("CREATE VIEW"), + DROP_NAMESPACE("DROP DATABASE/TABLE/NAMESPACE"), + DROP_FUNCTION("DROP FUNCTION"), + DROP_VIEW("DROP VIEW"), + DROP_TABLE("DROP TABLE"), + REPAIR_TABLE("REPAIR TABLE"), // does this conflict with DROP_NAMESPACE? + TRUNCATE_TABLE("TRUNCATE TABLE"), + // DML Statements + INSERT("INSERT"), + LOAD("LOAD"), + + // Data Retrieval Statements + EXPLAIN("EXPLAIN"), + WITH("WITH"), + CLUSTER_BY("CLUSTER BY"), + DISTRIBUTE_BY("DISTRIBUTE BY"), + GROUP_BY("GROUP BY"), + HAVING("HAVING"), + HINTS("HINTS"), + INLINE_TABLE("Inline Table(VALUES)"), + INNER_JOIN("INNER JOIN"), + CROSS_JOIN("CROSS JOIN"), + LEFT_OUTER_JOIN("LEFT OUTER JOIN"), + LEFT_SEMI_JOIN("LEFT SEMI JOIN"), + RIGHT_OUTER_JOIN("RIGHT OUTER JOIN"), + FULL_OUTER_JOIN("FULL OUTER JOIN"), + LEFT_ANTI_JOIN("LEFT ANTI JOIN"), + TABLESAMPLE("TABLESAMPLE"), + TABLE_VALUED_FUNCTION("Table-valued function"), + LATERAL_VIEW("LATERAL VIEW"), + LATERAL_SUBQUERY("LATERAL SUBQUERY"), + TRANSFORM("TRANSFORM"), + + // Auxiliary Statements + MANAGE_RESOURCE("Resource management statements"), + ANALYZE_TABLE("ANALYZE TABLE(S)"), + CACHE_TABLE("CACHE TABLE"), + CLEAR_CACHE("CLEAR CACHE"), + DESCRIBE_NAMESPACE("DESCRIBE (NAMESPACE|DATABASE|SCHEMA"), + DESCRIBE_FUNCTION("DESCRIBE FUNCTION"), + DESCRIBE_QUERY("DESCRIBE QUERY"), + DESCRIBE_TABLE("DESCRIBE TABLE"), + REFRESH_RESOURCE("REFRESH"), + REFRESH_TABLE("REFRESH TABLE"), + REFRESH_FUNCTION("REFRESH FUNCTION"), + RESET("RESET"), + SET("SET"), + SHOW_COLUMNS("SHOW COLUMNS"), + SHOW_CREATE_TABLE("SHOW CREATE TABLE"), + SHOW_NAMESPACES("SHOW (DATABASES|SCHEMAS)"), + SHOW_FUNCTIONS("SHOW FUNCTIONS"), + SHOW_PARTITIONS("SHOW PARTITIONS"), + SHOW_TABLE_EXTENDED("SHOW TABLE EXTENDED"), + SHOW_TABLES("SHOW TABLES"), + SHOW_TBLPROPERTIES("SHOW TBLPROPERTIES"), + SHOW_VIEWS("SHOW VIEWS"), + UNCACHE_TABLE("UNCACHE TABLE"), + + // Functions + MAP_FUNCTIONS("Map functions"), + CSV_FUNCTIONS("CSV functions"), + MISC_FUNCTIONS("Misc functions"), + + SELECT("SELECT"); + + String description; + + @Override + public String toString() { + return description; + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java new file mode 100644 index 0000000000..b11999b5d1 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +public interface GrammarElementValidator { + boolean isValid(GrammarElement element); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java new file mode 100644 index 0000000000..99cecf18ae --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.opensearch.sql.spark.validator.GrammarElement.*; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.Map; +import java.util.Set; +import org.opensearch.sql.datasource.model.DataSourceType; + +public class GrammarElementValidatorFactory { + private static final Set DEFAULT_DENY_LIST = + ImmutableSet.of(CREATE_FUNCTION, DROP_FUNCTION, INSERT, LOAD, HINTS, TABLESAMPLE); + + private static final Set CWL_DENY_LIST = + copyBuilder(DEFAULT_DENY_LIST) + .add( + ALTER_NAMESPACE, + ALTER_VIEW, + CREATE_NAMESPACE, + CREATE_VIEW, + DROP_NAMESPACE, + DROP_VIEW, + REPAIR_TABLE, + TRUNCATE_TABLE) + .build(); + + private static final Set S3GLUE_DENY_LIST = + copyBuilder(DEFAULT_DENY_LIST) + .add( + ALTER_VIEW, + CREATE_VIEW, + DROP_VIEW, + REPAIR_TABLE, + DISTRIBUTE_BY, + INLINE_TABLE, + TRUNCATE_TABLE, + CLUSTER_BY, + DISTRIBUTE_BY, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + TRANSFORM, + MANAGE_RESOURCE, + DESCRIBE_FUNCTION, + REFRESH_RESOURCE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_FUNCTIONS, + SHOW_VIEWS, + MISC_FUNCTIONS) + .build(); + + private static Map validatorMap = + ImmutableMap.of(DataSourceType.S3GLUE, new DenyListGrammarElementValidator(S3GLUE_DENY_LIST)); + + public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { + return validatorMap.get(dataSourceType); + } + + private static ImmutableSet.Builder copyBuilder(Set original) { + return ImmutableSet.builder().addAll(original); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java new file mode 100644 index 0000000000..a737c62071 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -0,0 +1,491 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import lombok.AllArgsConstructor; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewQueryContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewSchemaBindingContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeTablesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClearCacheContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClusterBySpecContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CtesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeQueryContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeRelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteDirContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteHiveDirContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshResourceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SelectClauseContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetQuantifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowCreateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowFunctionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowNamespacesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowPartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTableExtendedContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTablesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTblPropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; + +@AllArgsConstructor +public class SQLQueryValidator extends SqlBaseParserBaseVisitor { + private final GrammarElementValidator grammarElementValidator; + + public void validate(SqlBaseParser.SingleStatementContext statement) { + this.visit(statement); + } + + @Override + public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { + validateAllowed(GrammarElement.CREATE_FUNCTION); + return super.visitCreateFunction(ctx); + } + + @Override + public Void visitSelectClause(SelectClauseContext ctx) { + validateAllowed(GrammarElement.SELECT); + return super.visitSelectClause(ctx); + } + + @Override + public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetNamespaceProperties(ctx); + } + + @Override + public Void visitUnsetNamespaceProperties(UnsetNamespacePropertiesContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitUnsetNamespaceProperties(ctx); + } + + @Override + public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetNamespaceLocation(ctx); + } + + @Override + public Void visitAlterViewQuery(AlterViewQueryContext ctx) { + validateAllowed(GrammarElement.ALTER_VIEW); + return super.visitAlterViewQuery(ctx); + } + + @Override + public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { + validateAllowed(GrammarElement.ALTER_VIEW); + return super.visitAlterViewSchemaBinding(ctx); + } + + @Override + public Void visitRenameTable(RenameTableContext ctx) { + TerminalNode view = ctx.VIEW(); + TerminalNode table = ctx.TABLE(); + if (ctx.VIEW() != null) { + validateAllowed(GrammarElement.ALTER_VIEW); + } else if (ctx.TABLE() != null) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + } + + return super.visitRenameTable(ctx); + } + + @Override + public Void visitCreateNamespace(CreateNamespaceContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateNamespace(ctx); + } + + @Override + public Void visitDropNamespace(DropNamespaceContext ctx) { + validateAllowed(GrammarElement.DROP_NAMESPACE); + return super.visitDropNamespace(ctx); + } + + @Override + public Void visitCreateView(CreateViewContext ctx) { + validateAllowed(GrammarElement.CREATE_VIEW); + return super.visitCreateView(ctx); + } + + @Override + public Void visitDropView(DropViewContext ctx) { + validateAllowed(GrammarElement.DROP_VIEW); + return super.visitDropView(ctx); + } + + @Override + public Void visitDropFunction(DropFunctionContext ctx) { + validateAllowed(GrammarElement.DROP_FUNCTION); + return super.visitDropFunction(ctx); + } + + @Override + public Void visitInsertOverwriteTable(InsertOverwriteTableContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteTable(ctx); + } + + @Override + public Void visitInsertIntoReplaceWhere(InsertIntoReplaceWhereContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertIntoReplaceWhere(ctx); + } + + @Override + public Void visitInsertIntoTable(InsertIntoTableContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertIntoTable(ctx); + } + + @Override + public Void visitInsertOverwriteDir(InsertOverwriteDirContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteDir(ctx); + } + + @Override + public Void visitInsertOverwriteHiveDir(InsertOverwriteHiveDirContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteHiveDir(ctx); + } + + @Override + public Void visitLoadData(LoadDataContext ctx) { + validateAllowed(GrammarElement.LOAD); + return super.visitLoadData(ctx); + } + + @Override + public Void visitExplain(ExplainContext ctx) { + validateAllowed(GrammarElement.EXPLAIN); + return super.visitExplain(ctx); + } + + @Override + public Void visitCtes(CtesContext ctx) { + validateAllowed(GrammarElement.WITH); + return super.visitCtes(ctx); + } + + @Override + public Void visitClusterBySpec(ClusterBySpecContext ctx) { + validateAllowed(GrammarElement.CLUSTER_BY); + return super.visitClusterBySpec(ctx); + } + + @Override + public Void visitQueryOrganization(QueryOrganizationContext ctx) { + if (ctx.CLUSTER() != null) { + validateAllowed(GrammarElement.CLUSTER_BY); + } else if (ctx.DISTRIBUTE() != null) { + validateAllowed(GrammarElement.DISTRIBUTE_BY); + } + return super.visitQueryOrganization(ctx); + } + + @Override + public Void visitHint(HintContext ctx) { + validateAllowed(GrammarElement.HINTS); + return super.visitHint(ctx); + } + + @Override + public Void visitInlineTable(InlineTableContext ctx) { + validateAllowed(GrammarElement.INLINE_TABLE); + return super.visitInlineTable(ctx); + } + + @Override + public Void visitJoinType(JoinTypeContext ctx) { + if (ctx.CROSS() != null) { + validateAllowed(GrammarElement.CROSS_JOIN); + } else if (ctx.LEFT() != null && ctx.SEMI() != null) { + validateAllowed(GrammarElement.LEFT_SEMI_JOIN); + } else if (ctx.ANTI() != null) { + validateAllowed(GrammarElement.LEFT_ANTI_JOIN); + } else if (ctx.LEFT() != null) { + validateAllowed(GrammarElement.LEFT_OUTER_JOIN); + } else if (ctx.RIGHT() != null) { + validateAllowed(GrammarElement.RIGHT_OUTER_JOIN); + } else if (ctx.FULL() != null) { + validateAllowed(GrammarElement.FULL_OUTER_JOIN); + } else { + validateAllowed(GrammarElement.INNER_JOIN); + } + return super.visitJoinType(ctx); + } + + @Override + public Void visitSample(SampleContext ctx) { + validateAllowed(GrammarElement.TABLESAMPLE); + return super.visitSample(ctx); + } + + @Override + public Void visitTableValuedFunction(TableValuedFunctionContext ctx) { + validateAllowed(GrammarElement.TABLE_VALUED_FUNCTION); + return super.visitTableValuedFunction(ctx); + } + + @Override + public Void visitLateralView(LateralViewContext ctx) { + validateAllowed(GrammarElement.LATERAL_VIEW); + return super.visitLateralView(ctx); + } + + @Override + public Void visitRelation(RelationContext ctx) { + if (ctx.LATERAL() != null) { + validateAllowed(GrammarElement.LATERAL_SUBQUERY); + } + return super.visitRelation(ctx); + } + + @Override + public Void visitJoinRelation(JoinRelationContext ctx) { + if (ctx.LATERAL() != null) { + validateAllowed(GrammarElement.LATERAL_SUBQUERY); + } + return super.visitJoinRelation(ctx); + } + + @Override + public Void visitTransformClause(TransformClauseContext ctx) { + if (ctx.TRANSFORM() != null) { + validateAllowed(GrammarElement.TRANSFORM); + } + return super.visitTransformClause(ctx); + } + + @Override + public Void visitManageResource(ManageResourceContext ctx) { + validateAllowed(GrammarElement.MANAGE_RESOURCE); + return super.visitManageResource(ctx); + } + + @Override + public Void visitAnalyze(AnalyzeContext ctx) { + validateAllowed(GrammarElement.ANALYZE_TABLE); + return super.visitAnalyze(ctx); + } + + @Override + public Void visitAnalyzeTables(AnalyzeTablesContext ctx) { + validateAllowed(GrammarElement.ANALYZE_TABLE); + return super.visitAnalyzeTables(ctx); + } + + @Override + public Void visitCacheTable(CacheTableContext ctx) { + validateAllowed(GrammarElement.CACHE_TABLE); + return super.visitCacheTable(ctx); + } + + @Override + public Void visitClearCache(ClearCacheContext ctx) { + validateAllowed(GrammarElement.CLEAR_CACHE); + return super.visitClearCache(ctx); + } + + @Override + public Void visitDescribeNamespace(DescribeNamespaceContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_NAMESPACE); + return super.visitDescribeNamespace(ctx); + } + + @Override + public Void visitDescribeFunction(DescribeFunctionContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_FUNCTION); + return super.visitDescribeFunction(ctx); + } + + @Override + public Void visitDescribeRelation(DescribeRelationContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_TABLE); + return super.visitDescribeRelation(ctx); + } + + @Override + public Void visitDescribeQuery(DescribeQueryContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_QUERY); + return super.visitDescribeQuery(ctx); + } + + @Override + public Void visitRefreshResource(RefreshResourceContext ctx) { + validateAllowed(GrammarElement.REFRESH_RESOURCE); + return super.visitRefreshResource(ctx); + } + + @Override + public Void visitRefreshTable(RefreshTableContext ctx) { + validateAllowed(GrammarElement.REFRESH_TABLE); + return super.visitRefreshTable(ctx); + } + + @Override + public Void visitRefreshFunction(RefreshFunctionContext ctx) { + validateAllowed(GrammarElement.REFRESH_FUNCTION); + return super.visitRefreshFunction(ctx); + } + + @Override + public Void visitResetConfiguration(ResetConfigurationContext ctx) { + validateAllowed(GrammarElement.RESET); + return super.visitResetConfiguration(ctx); + } + + @Override + public Void visitResetQuotedConfiguration(ResetQuotedConfigurationContext ctx) { + validateAllowed(GrammarElement.RESET); + return super.visitResetQuotedConfiguration(ctx); + } + + @Override + public Void visitSetConfiguration(SetConfigurationContext ctx) { + validateAllowed(GrammarElement.SET); + return super.visitSetConfiguration(ctx); + } + + @Override + public Void visitSetQuantifier(SetQuantifierContext ctx) { + validateAllowed(GrammarElement.SET); + return super.visitSetQuantifier(ctx); + } + + @Override + public Void visitShowColumns(ShowColumnsContext ctx) { + validateAllowed(GrammarElement.SHOW_COLUMNS); + return super.visitShowColumns(ctx); + } + + @Override + public Void visitShowCreateTable(ShowCreateTableContext ctx) { + validateAllowed(GrammarElement.SHOW_CREATE_TABLE); + return super.visitShowCreateTable(ctx); + } + + @Override + public Void visitShowNamespaces(ShowNamespacesContext ctx) { + validateAllowed(GrammarElement.SHOW_NAMESPACES); + return super.visitShowNamespaces(ctx); + } + + @Override + public Void visitShowFunctions(ShowFunctionsContext ctx) { + validateAllowed(GrammarElement.SHOW_FUNCTIONS); + return super.visitShowFunctions(ctx); + } + + @Override + public Void visitShowPartitions(ShowPartitionsContext ctx) { + validateAllowed(GrammarElement.SHOW_PARTITIONS); + return super.visitShowPartitions(ctx); + } + + @Override + public Void visitShowTableExtended(ShowTableExtendedContext ctx) { + validateAllowed(GrammarElement.SHOW_TABLE_EXTENDED); + return super.visitShowTableExtended(ctx); + } + + @Override + public Void visitShowTables(ShowTablesContext ctx) { + validateAllowed(GrammarElement.SHOW_TABLES); + return super.visitShowTables(ctx); + } + + @Override + public Void visitShowTblProperties(ShowTblPropertiesContext ctx) { + validateAllowed(GrammarElement.SHOW_TBLPROPERTIES); + return super.visitShowTblProperties(ctx); + } + + @Override + public Void visitShowViews(ShowViewsContext ctx) { + validateAllowed(GrammarElement.SHOW_VIEWS); + return super.visitShowViews(ctx); + } + + @Override + public Void visitUncacheTable(UncacheTableContext ctx) { + validateAllowed(GrammarElement.UNCACHE_TABLE); + return super.visitUncacheTable(ctx); + } + + @Override + public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) { + String function = ctx.function.getText().toLowerCase(); + if (isMapFunctions(function)) { + validateAllowed(GrammarElement.MAP_FUNCTIONS); + } else if (isCsvFunctions(function)) { + validateAllowed(GrammarElement.CSV_FUNCTIONS); + } else if (isMiscFunctions(function)) { + validateAllowed(GrammarElement.MISC_FUNCTIONS); + } + return super.visitFunctionIdentifier(ctx); + } + + private boolean isMapFunctions(String function) { + // TODO: to be implemented + return false; + } + + private boolean isCsvFunctions(String function) { + // TODO: to be implemented + return false; + } + + private boolean isMiscFunctions(String function) { + // TODO: to be implemented + return false; + } + + private void validateAllowed(GrammarElement element) { + if (!grammarElementValidator.isValid(element)) { + throw new IllegalArgumentException(element + " is not allowed."); + } + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java new file mode 100644 index 0000000000..85f9d0f284 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -0,0 +1,301 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import lombok.AllArgsConstructor; +import org.antlr.v4.runtime.CommonTokenStream; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SingleStatementContext; + +class SQLQueryValidatorTest { + GrammarElementValidatorFactory factory = new GrammarElementValidatorFactory(); + + @AllArgsConstructor + private enum TestQuery { + // DDL Statements + ALTER_DATABASE( + "ALTER DATABASE inventory SET DBPROPERTIES ('Edited-by' = 'John', 'Edit-date' =" + + " '01/01/2001');"), + ALTER_TABLE( + "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');"), + ALTER_VIEW("ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;"), + CREATE_DATABASE("CREATE DATABASE IF NOT EXISTS customer_db;\n"), + CREATE_FUNCTION("CREATE FUNCTION simple_udf AS 'SimpleUdf' USING JAR '/tmp/SimpleUdf.jar';"), + CREATE_TABLE("CREATE TABLE Student_Dupli like Student;"), + CREATE_VIEW( + "CREATE OR REPLACE VIEW experienced_employee" + + " (ID COMMENT 'Unique identification number', Name)" + + " COMMENT 'View for experienced employees'" + + " AS SELECT id, name FROM all_employee" + + " WHERE working_years > 5;"), + DROP_DATABASE("DROP DATABASE inventory_db CASCADE;"), + DROP_FUNCTION("DROP FUNCTION test_avg;"), + DROP_TABLE("DROP TABLE employeetable;"), + DROP_VIEW("DROP VIEW employeeView;"), + REPAIR_TABLE("REPAIR TABLE t1;"), + TRUNCATE_TABLE("TRUNCATE TABLE Student partition(age=10);"), + + // DML Statements + INSERT_TABLE("INSERT INTO target_table SELECT * FROM source_table;"), + INSERT_OVERWRITE_DIRECTORY( + "INSERT OVERWRITE DIRECTORY '/path/to/output' SELECT * FROM source_table;"), + LOAD("LOAD DATA INPATH '/path/to/data' INTO TABLE target_table;"), + + // Data Retrieval Statements + SELECT("SELECT 1"), + EXPLAIN("EXPLAIN SELECT * FROM my_table;"), + COMMON_TABLE_EXPRESSION( + "WITH cte AS (SELECT * FROM my_table WHERE age > 30) SELECT * FROM cte;"), + CLUSTER_BY_CLAUSE("SELECT * FROM my_table CLUSTER BY age;"), + DISTRIBUTE_BY_CLAUSE("SELECT * FROM my_table DISTRIBUTE BY name;"), + GROUP_BY_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name;"), + HAVING_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name HAVING count(*) > 1;"), + HINTS("SELECT /*+ BROADCAST(my_table) */ * FROM my_table;"), + INLINE_TABLE("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS inline_table(id, value);"), + FILE("SELECT * FROM text.`/path/to/file.txt`;"), + INNER_JOIN("SELECT t1.name, t2.age FROM table1 t1 INNER JOIN table2 t2 ON t1.id = t2.id;"), + CROSS_JOIN("SELECT t1.name, t2.age FROM table1 t1 CROSS JOIN table2 t2;"), + LEFT_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 LEFT OUTER JOIN table2 t2 ON t1.id = t2.id;"), + LEFT_SEMI_JOIN("SELECT t1.name FROM table1 t1 LEFT SEMI JOIN table2 t2 ON t1.id = t2.id;"), + RIGHT_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 RIGHT OUTER JOIN table2 t2 ON t1.id = t2.id;"), + FULL_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 FULL OUTER JOIN table2 t2 ON t1.id = t2.id;"), + LEFT_ANTI_JOIN("SELECT t1.name FROM table1 t1 LEFT ANTI JOIN table2 t2 ON t1.id = t2.id;"), + LIKE_PREDICATE("SELECT * FROM my_table WHERE name LIKE 'A%';"), + LIMIT_CLAUSE("SELECT * FROM my_table LIMIT 10;"), + OFFSET_CLAUSE("SELECT * FROM my_table OFFSET 5 ROWS;"), + ORDER_BY_CLAUSE("SELECT * FROM my_table ORDER BY age DESC;"), + SET_OPERATORS("SELECT * FROM table1 UNION SELECT * FROM table2;"), + SORT_BY_CLAUSE("SELECT * FROM my_table SORT BY age DESC;"), + TABLESAMPLE("SELECT * FROM my_table TABLESAMPLE(10 PERCENT);"), + // TABLE_VALUED_FUNCTION("SELECT explode(array(10, 20));"), TODO: Need to handle this case + TABLE_VALUED_FUNCTION("SELECT * FROM explode(array(10, 20));"), + WHERE_CLAUSE("SELECT * FROM my_table WHERE age > 30;"), + AGGREGATE_FUNCTION("SELECT count(*) FROM my_table;"), + WINDOW_FUNCTION("SELECT name, age, rank() OVER (ORDER BY age DESC) FROM my_table;"), + CASE_CLAUSE("SELECT name, CASE WHEN age > 30 THEN 'Adult' ELSE 'Young' END FROM my_table;"), + PIVOT_CLAUSE( + "SELECT * FROM (SELECT name, age, gender FROM my_table) PIVOT (COUNT(*) FOR gender IN ('M'," + + " 'F'));"), + UNPIVOT_CLAUSE( + "SELECT name, value, category FROM (SELECT name, 'M' AS gender, age AS male_age, 0 AS" + + " female_age FROM my_table) UNPIVOT (value FOR category IN (male_age, female_age));"), + LATERAL_VIEW_CLAUSE( + "SELECT name, age, exploded_value FROM my_table LATERAL VIEW OUTER EXPLODE(split(comments," + + " ',')) exploded_table AS exploded_value;"), + LATERAL_SUBQUERY( + "SELECT name, age, (SELECT max(age) FROM my_table t2 WHERE t1.age < t2.age) AS next_age" + + " FROM my_table t1;"), + TRANSFORM_CLAUSE( + "SELECT transform(zip_code, name, age) USING 'cat' AS (a, b, c) FROM my_table;"), + + // Auxiliary Statements + ADD_FILE("ADD FILE /tmp/test.txt;"), + ADD_JAR("ADD JAR /path/to/my.jar;"), + ANALYZE_TABLE("ANALYZE TABLE my_table COMPUTE STATISTICS;"), + CACHE_TABLE("CACHE TABLE my_table;"), + CLEAR_CACHE("CLEAR CACHE;"), + DESCRIBE_DATABASE("DESCRIBE DATABASE my_db;"), + DESCRIBE_FUNCTION("DESCRIBE FUNCTION my_function;"), + DESCRIBE_QUERY("DESCRIBE QUERY SELECT * FROM my_table;"), + DESCRIBE_TABLE("DESCRIBE TABLE my_table;"), + LIST_FILE("LIST FILE '/path/to/files';"), + LIST_JAR("LIST JAR;"), + REFRESH("REFRESH;"), + REFRESH_TABLE("REFRESH TABLE my_table;"), + REFRESH_FUNCTION("REFRESH FUNCTION my_function;"), + RESET("RESET;"), + SET("SET spark.sql.shuffle.partitions=200;"), + SHOW_COLUMNS("SHOW COLUMNS FROM my_table;"), + SHOW_CREATE_TABLE("SHOW CREATE TABLE my_table;"), + SHOW_DATABASES("SHOW DATABASES;"), + SHOW_FUNCTIONS("SHOW FUNCTIONS;"), + SHOW_PARTITIONS("SHOW PARTITIONS my_table;"), + SHOW_TABLE_EXTENDED("SHOW TABLE EXTENDED LIKE 'my_table';"), + SHOW_TABLES("SHOW TABLES;"), + SHOW_TBLPROPERTIES("SHOW TBLPROPERTIES my_table;"), + SHOW_VIEWS("SHOW VIEWS;"), + UNCACHE_TABLE("UNCACHE TABLE my_table;"), + + // Functions + ARRAY_FUNCTIONS("SELECT array_contains(array(1, 2, 3), 2);"), + MAP_FUNCTIONS("SELECT map_keys(map('a', 1, 'b', 2));"), + DATE_AND_TIMESTAMP_FUNCTIONS("SELECT date_format(current_date(), 'yyyy-MM-dd');"), + JSON_FUNCTIONS("SELECT json_tuple('{\"a\":1, \"b\":2}', 'a', 'b');"), + MATHEMATICAL_FUNCTIONS("SELECT round(3.1415, 2);"), + STRING_FUNCTIONS("SELECT concat('Hello', ' ', 'World');"), + BITWISE_FUNCTIONS("SELECT bitwiseNOT(42);"), + CONVERSION_FUNCTIONS("SELECT cast('2023-04-01' as date);"), + CONDITIONAL_FUNCTIONS("SELECT if(1 > 0, 'true', 'false');"), + PREDICATE_FUNCTIONS("SELECT array_exists(array(1, 2, 3), x -> x > 2);"), + CSV_FUNCTIONS("SELECT csv_from_array(array('a', 'b', 'c'), ',');"), + MISC_FUNCTIONS("SELECT hash('Hello World');"), + + // Aggregate-like Functions + AGGREGATE_FUNCTIONS("SELECT count(*), max(age), min(age) FROM my_table;"), + WINDOW_FUNCTIONS("SELECT name, age, rank() OVER (ORDER BY age DESC) FROM my_table;"), + + // Generator Functions + GENERATOR_FUNCTIONS("SELECT explode(array(1, 2, 3));"), + + // UDFs (User-Defined Functions) + SCALAR_USER_DEFINED_FUNCTIONS("SELECT my_udf(name) FROM my_table;"), + USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"), + INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS("SELECT my_hive_udf(name) FROM my_table;"); + + private final String query; + + @Override + public String toString() { + return query; + } + } + + @Test + void s3glueQueries() { + SQLQueryValidator v = + new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE)); + verifyValid(v, TestQuery.ALTER_DATABASE); + verifyValid(v, TestQuery.ALTER_TABLE); + verifyInvalid(v, TestQuery.ALTER_VIEW); + verifyValid(v, TestQuery.CREATE_DATABASE); + verifyInvalid(v, TestQuery.CREATE_FUNCTION); + verifyValid(v, TestQuery.CREATE_TABLE); + verifyInvalid(v, TestQuery.CREATE_VIEW); + verifyValid(v, TestQuery.DROP_DATABASE); + verifyInvalid(v, TestQuery.DROP_FUNCTION); + verifyValid(v, TestQuery.DROP_TABLE); + verifyInvalid(v, TestQuery.DROP_VIEW); + verifyValid(v, TestQuery.REPAIR_TABLE); + verifyValid(v, TestQuery.TRUNCATE_TABLE); + + // DML Statements + verifyInvalid(v, TestQuery.INSERT_TABLE); + verifyInvalid(v, TestQuery.INSERT_OVERWRITE_DIRECTORY); + verifyInvalid(v, TestQuery.LOAD); + + // Data Retrieval + verifyValid(v, TestQuery.SELECT); + verifyValid(v, TestQuery.EXPLAIN); + verifyValid(v, TestQuery.COMMON_TABLE_EXPRESSION); + verifyInvalid(v, TestQuery.CLUSTER_BY_CLAUSE); + verifyInvalid(v, TestQuery.DISTRIBUTE_BY_CLAUSE); + verifyValid(v, TestQuery.GROUP_BY_CLAUSE); + verifyValid(v, TestQuery.HAVING_CLAUSE); + verifyInvalid(v, TestQuery.HINTS); + verifyInvalid(v, TestQuery.INLINE_TABLE); + // verifyInvalid(v, TestQuery.FILE); TODO: need dive deep + verifyValid(v, TestQuery.INNER_JOIN); + verifyInvalid(v, TestQuery.CROSS_JOIN); + verifyValid(v, TestQuery.LEFT_OUTER_JOIN); + verifyInvalid(v, TestQuery.LEFT_SEMI_JOIN); + verifyInvalid(v, TestQuery.RIGHT_OUTER_JOIN); + verifyInvalid(v, TestQuery.FULL_OUTER_JOIN); + verifyInvalid(v, TestQuery.LEFT_ANTI_JOIN); + verifyValid(v, TestQuery.LIKE_PREDICATE); + verifyValid(v, TestQuery.LIMIT_CLAUSE); + verifyValid(v, TestQuery.OFFSET_CLAUSE); + verifyValid(v, TestQuery.ORDER_BY_CLAUSE); + verifyValid(v, TestQuery.SET_OPERATORS); + verifyValid(v, TestQuery.SORT_BY_CLAUSE); + verifyInvalid(v, TestQuery.TABLESAMPLE); + verifyInvalid(v, TestQuery.TABLE_VALUED_FUNCTION); + verifyValid(v, TestQuery.WHERE_CLAUSE); + verifyValid(v, TestQuery.AGGREGATE_FUNCTION); + verifyValid(v, TestQuery.WINDOW_FUNCTION); + verifyValid(v, TestQuery.CASE_CLAUSE); + verifyValid(v, TestQuery.PIVOT_CLAUSE); + verifyValid(v, TestQuery.UNPIVOT_CLAUSE); + verifyValid(v, TestQuery.LATERAL_VIEW_CLAUSE); + verifyValid(v, TestQuery.LATERAL_SUBQUERY); + verifyInvalid(v, TestQuery.TRANSFORM_CLAUSE); + + // Auxiliary Statements + verifyInvalid(v, TestQuery.ADD_FILE); + verifyInvalid(v, TestQuery.ADD_JAR); + verifyValid(v, TestQuery.ANALYZE_TABLE); + verifyValid(v, TestQuery.CACHE_TABLE); + verifyValid(v, TestQuery.CLEAR_CACHE); + verifyValid(v, TestQuery.DESCRIBE_DATABASE); + verifyInvalid(v, TestQuery.DESCRIBE_FUNCTION); + verifyValid(v, TestQuery.DESCRIBE_QUERY); + verifyValid(v, TestQuery.DESCRIBE_TABLE); + verifyInvalid(v, TestQuery.LIST_FILE); + verifyInvalid(v, TestQuery.LIST_JAR); + verifyInvalid(v, TestQuery.REFRESH); + // verifyValid(v, TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to + // refreshResource) + verifyInvalid(v, TestQuery.REFRESH_FUNCTION); + verifyInvalid(v, TestQuery.RESET); + verifyInvalid(v, TestQuery.SET); + verifyValid(v, TestQuery.SHOW_COLUMNS); + verifyValid(v, TestQuery.SHOW_CREATE_TABLE); + verifyValid(v, TestQuery.SHOW_DATABASES); + verifyInvalid(v, TestQuery.SHOW_FUNCTIONS); + verifyValid(v, TestQuery.SHOW_PARTITIONS); + verifyValid(v, TestQuery.SHOW_TABLE_EXTENDED); + verifyValid(v, TestQuery.SHOW_TABLES); + verifyValid(v, TestQuery.SHOW_TBLPROPERTIES); + verifyInvalid(v, TestQuery.SHOW_VIEWS); + verifyValid(v, TestQuery.UNCACHE_TABLE); + + // Functions + // verifyValid(v, TestQuery.ARRAY_FUNCTIONS); + // verifyValid(v, TestQuery.MAP_FUNCTIONS); + // verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); + // verifyValid(v, TestQuery.JSON_FUNCTIONS); + // verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); + // verifyValid(v, TestQuery.STRING_FUNCTIONS); + // verifyValid(v, TestQuery.BITWISE_FUNCTIONS); + // verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); + // verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); + // verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); + // verifyValid(v, TestQuery.CSV_FUNCTIONS); + // verifyValid(v, TestQuery.MISC_FUNCTIONS); + + // Aggregate-like Functions + // verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); + // verifyValid(v, TestQuery.WINDOW_FUNCTIONS); + + // Generator Functions + // verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); + + // UDFs + // verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); + // verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); + // verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + + void verifyValid(SQLQueryValidator validator, TestQuery query) { + runValidate(validator, query.toString()); + } + + void verifyInvalid(SQLQueryValidator validator, TestQuery query) { + assertThrows( + IllegalArgumentException.class, + () -> runValidate(validator, query.toString()), + "The query should throw: query=`" + query.toString() + "`"); + } + + void runValidate(SQLQueryValidator validator, String query) { + validator.validate(getParser(query)); + } + + SingleStatementContext getParser(String query) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(query)))); + return sqlBaseParser.singleStatement(); + } +} From d77e408f5aadc36e0009cfb1687b0607de8d62ab Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Tue, 17 Sep 2024 15:01:29 -0700 Subject: [PATCH 02/14] Add function types Signed-off-by: Tomoyuki Morita --- .../sql/spark/validator/FunctionType.java | 432 ++++++++++++++++++ .../sql/spark/validator/GrammarElement.java | 3 +- .../GrammarElementValidatorFactory.java | 47 +- .../spark/validator/SQLQueryValidator.java | 47 +- .../sql/spark/validator/FunctionTypeTest.java | 46 ++ .../validator/SQLQueryValidatorTest.java | 46 +- 6 files changed, 570 insertions(+), 51 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java new file mode 100644 index 0000000000..0a821a7a8c --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java @@ -0,0 +1,432 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public enum FunctionType { + AGGREGATE("Aggregate"), + WINDOW("Window"), + ARRAY("Array"), + MAP("Map"), + DATE_TIMESTAMP("Date and Timestamp"), + JSON("JSON"), + MATH("Math"), + STRING("String"), + CONDITIONAL("Conditional"), + BITWISE("Bitwise"), + CONVERSION("Conversion"), + PREDICATE("Predicate"), + CSV("CSV"), + MISC("Misc"), + GENERATOR("Generator"), + UDF("User Defined Function"); + + private final String name; + + private static final Map> FUNCTION_TYPE_TO_FUNCTION_NAMES_MAP = + ImmutableMap.>builder() + .put( + AGGREGATE, + Set.of( + "any", + "any_value", + "approx_count_distinct", + "approx_percentile", + "array_agg", + "avg", + "bit_and", + "bit_or", + "bit_xor", + "bitmap_construct_agg", + "bitmap_or_agg", + "bool_and", + "bool_or", + "collect_list", + "collect_set", + "corr", + "count", + "count_if", + "count_min_sketch", + "covar_pop", + "covar_samp", + "every", + "first", + "first_value", + "grouping", + "grouping_id", + "histogram_numeric", + "hll_sketch_agg", + "hll_union_agg", + "kurtosis", + "last", + "last_value", + "max", + "max_by", + "mean", + "median", + "min", + "min_by", + "mode", + "percentile", + "percentile_approx", + "regr_avgx", + "regr_avgy", + "regr_count", + "regr_intercept", + "regr_r2", + "regr_slope", + "regr_sxx", + "regr_sxy", + "regr_syy", + "skewness", + "some", + "std", + "stddev", + "stddev_pop", + "stddev_samp", + "sum", + "try_avg", + "try_sum", + "var_pop", + "var_samp", + "variance")) + .put( + WINDOW, + Set.of( + "cume_dist", + "dense_rank", + "lag", + "lead", + "nth_value", + "ntile", + "percent_rank", + "rank", + "row_number")) + .put( + ARRAY, + Set.of( + "array", + "array_append", + "array_compact", + "array_contains", + "array_distinct", + "array_except", + "array_insert", + "array_intersect", + "array_join", + "array_max", + "array_min", + "array_position", + "array_prepend", + "array_remove", + "array_repeat", + "array_union", + "arrays_overlap", + "arrays_zip", + "flatten", + "get", + "sequence", + "shuffle", + "slice", + "sort_array")) + .put( + MAP, + Set.of( + "element_at", + "map", + "map_concat", + "map_contains_key", + "map_entries", + "map_from_arrays", + "map_from_entries", + "map_keys", + "map_values", + "str_to_map", + "try_element_at")) + .put( + DATE_TIMESTAMP, + Set.of( + "add_months", + "convert_timezone", + "curdate", + "current_date", + "current_timestamp", + "current_timezone", + "date_add", + "date_diff", + "date_format", + "date_from_unix_date", + "date_part", + "date_sub", + "date_trunc", + "dateadd", + "datediff", + "datepart", + "day", + "dayofmonth", + "dayofweek", + "dayofyear", + "extract", + "from_unixtime", + "from_utc_timestamp", + "hour", + "last_day", + "localtimestamp", + "make_date", + "make_dt_interval", + "make_interval", + "make_timestamp", + "make_timestamp_ltz", + "make_timestamp_ntz", + "make_ym_interval", + "minute", + "month", + "months_between", + "next_day", + "now", + "quarter", + "second", + "session_window", + "timestamp_micros", + "timestamp_millis", + "timestamp_seconds", + "to_date", + "to_timestamp", + "to_timestamp_ltz", + "to_timestamp_ntz", + "to_unix_timestamp", + "to_utc_timestamp", + "trunc", + "try_to_timestamp", + "unix_date", + "unix_micros", + "unix_millis", + "unix_seconds", + "unix_timestamp", + "weekday", + "weekofyear", + "window", + "window_time", + "year")) + .put( + JSON, + Set.of( + "from_json", + "get_json_object", + "json_array_length", + "json_object_keys", + "json_tuple", + "schema_of_json", + "to_json")) + .put( + MATH, + Set.of( + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bin", + "bround", + "cbrt", + "ceil", + "ceiling", + "conv", + "cos", + "cosh", + "cot", + "csc", + "degrees", + "e", + "exp", + "expm1", + "factorial", + "floor", + "greatest", + "hex", + "hypot", + "least", + "ln", + "log", + "log10", + "log1p", + "log2", + "negative", + "pi", + "pmod", + "positive", + "pow", + "power", + "radians", + "rand", + "randn", + "random", + "rint", + "round", + "sec", + "shiftleft", + "sign", + "signum", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "try_add", + "try_divide", + "try_multiply", + "try_subtract", + "unhex", + "width_bucket")) + .put( + STRING, + Set.of( + "ascii", + "base64", + "bit_length", + "btrim", + "char", + "char_length", + "character_length", + "chr", + "concat", + "concat_ws", + "contains", + "decode", + "elt", + "encode", + "endswith", + "find_in_set", + "format_number", + "format_string", + "initcap", + "instr", + "lcase", + "left", + "len", + "length", + "levenshtein", + "locate", + "lower", + "lpad", + "ltrim", + "luhn_check", + "mask", + "octet_length", + "overlay", + "position", + "printf", + "regexp_count", + "regexp_extract", + "regexp_extract_all", + "regexp_instr", + "regexp_replace", + "regexp_substr", + "repeat", + "replace", + "right", + "rpad", + "rtrim", + "sentences", + "soundex", + "space", + "split", + "split_part", + "startswith", + "substr", + "substring", + "substring_index", + "to_binary", + "to_char", + "to_number", + "to_varchar", + "translate", + "trim", + "try_to_binary", + "try_to_number", + "ucase", + "unbase64", + "upper")) + .put(CONDITIONAL, Set.of("coalesce", "if", "ifnull", "nanvl", "nullif", "nvl", "nvl2")) + .put( + BITWISE, Set.of("bit_count", "bit_get", "getbit", "shiftright", "shiftrightunsigned")) + .put( + CONVERSION, + Set.of( + "bigint", + "binary", + "boolean", + "cast", + "date", + "decimal", + "double", + "float", + "int", + "smallint", + "string", + "timestamp", + "tinyint")) + .put(PREDICATE, Set.of("isnan", "isnotnull", "isnull", "regexp", "regexp_like", "rlike")) + .put(CSV, Set.of("from_csv","schema_of_csv","to_csv")) + .put( + MISC, + Set.of( + "aes_decrypt", + "aes_encrypt", + "assert_true", + "bitmap_bit_position", + "bitmap_bucket_number", + "bitmap_count", + "current_catalog", + "current_database", + "current_schema", + "current_user", + "equal_null", + "hll_sketch_estimate", + "hll_union", + "input_file_block_length", + "input_file_block_start", + "input_file_name", + "java_method", + "monotonically_increasing_id", + "reflect", + "spark_partition_id", + "try_aes_decrypt", + "typeof", + "user", + "uuid", + "version")) + .put( + GENERATOR, + Set.of( + "explode", + "explode_outer", + "inline", + "inline_outer", + "posexplode", + "posexplode_outer", + "stack")) + .build(); + + private static final Map FUNCTION_NAME_TO_FUNCTION_TYPE_MAP = + FUNCTION_TYPE_TO_FUNCTION_NAMES_MAP.entrySet().stream() + .flatMap( + entry -> entry.getValue().stream().map(value -> Map.entry(value, entry.getKey()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + public static FunctionType fromFunctionName(String functionName) { + return FUNCTION_NAME_TO_FUNCTION_TYPE_MAP.getOrDefault(functionName.toLowerCase(), UDF); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java index 562a83dcd4..05d878e9dc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -76,7 +76,8 @@ enum GrammarElement { CSV_FUNCTIONS("CSV functions"), MISC_FUNCTIONS("Misc functions"), - SELECT("SELECT"); + // UDF + UDF("User Defined functions"); String description; diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index 99cecf18ae..57374b852c 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -27,7 +27,49 @@ public class GrammarElementValidatorFactory { DROP_NAMESPACE, DROP_VIEW, REPAIR_TABLE, - TRUNCATE_TABLE) + TRUNCATE_TABLE, + EXPLAIN, + WITH, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + LATERAL_VIEW, + LATERAL_SUBQUERY, + TRANSFORM, + MANAGE_RESOURCE, + ANALYZE_TABLE, + CACHE_TABLE, + DESCRIBE_NAMESPACE, + DESCRIBE_FUNCTION, + DESCRIBE_QUERY, + DESCRIBE_TABLE, + REFRESH_RESOURCE, + REFRESH_TABLE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_COLUMNS, + SHOW_CREATE_TABLE, + SHOW_NAMESPACES, + SHOW_FUNCTIONS, + SHOW_PARTITIONS, + SHOW_TABLE_EXTENDED, + SHOW_TABLES, + SHOW_TBLPROPERTIES, + SHOW_VIEWS, + UNCACHE_TABLE, + CSV_FUNCTIONS, + MISC_FUNCTIONS, + UDF + ) .build(); private static final Set S3GLUE_DENY_LIST = @@ -58,7 +100,8 @@ public class GrammarElementValidatorFactory { SET, SHOW_FUNCTIONS, SHOW_VIEWS, - MISC_FUNCTIONS) + MISC_FUNCTIONS, + UDF) .build(); private static Map validatorMap = diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index a737c62071..fba973f930 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -27,6 +27,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; @@ -82,12 +83,6 @@ public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { return super.visitCreateFunction(ctx); } - @Override - public Void visitSelectClause(SelectClauseContext ctx) { - validateAllowed(GrammarElement.SELECT); - return super.visitSelectClause(ctx); - } - @Override public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) { validateAllowed(GrammarElement.ALTER_NAMESPACE); @@ -457,30 +452,32 @@ public Void visitUncacheTable(UncacheTableContext ctx) { @Override public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) { - String function = ctx.function.getText().toLowerCase(); - if (isMapFunctions(function)) { - validateAllowed(GrammarElement.MAP_FUNCTIONS); - } else if (isCsvFunctions(function)) { - validateAllowed(GrammarElement.CSV_FUNCTIONS); - } else if (isMiscFunctions(function)) { - validateAllowed(GrammarElement.MISC_FUNCTIONS); - } + validateFunctionAllowed(ctx.function.getText()); return super.visitFunctionIdentifier(ctx); } - private boolean isMapFunctions(String function) { - // TODO: to be implemented - return false; - } - - private boolean isCsvFunctions(String function) { - // TODO: to be implemented - return false; + @Override + public Void visitFunctionName(FunctionNameContext ctx) { + validateFunctionAllowed(ctx.qualifiedName().getText()); + return super.visitFunctionName(ctx); } - private boolean isMiscFunctions(String function) { - // TODO: to be implemented - return false; + private void validateFunctionAllowed(String function) { + FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); + switch(type) { + case MAP: + validateAllowed(GrammarElement.MAP_FUNCTIONS); + break; + case CSV: + validateAllowed(GrammarElement.CSV_FUNCTIONS); + break; + case MISC: + validateAllowed(GrammarElement.MISC_FUNCTIONS); + break; + case UDF: + validateAllowed(GrammarElement.UDF); + break; + } } private void validateAllowed(GrammarElement element) { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java new file mode 100644 index 0000000000..920d35df2f --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class FunctionTypeTest { + @Test + public void test() { + assertEquals(FunctionType.AGGREGATE, FunctionType.fromFunctionName("any")); + assertEquals(FunctionType.AGGREGATE, FunctionType.fromFunctionName("variance")); + assertEquals(FunctionType.WINDOW, FunctionType.fromFunctionName("cume_dist")); + assertEquals(FunctionType.WINDOW, FunctionType.fromFunctionName("row_number")); + assertEquals(FunctionType.ARRAY, FunctionType.fromFunctionName("array")); + assertEquals(FunctionType.ARRAY, FunctionType.fromFunctionName("sort_array")); + assertEquals(FunctionType.MAP, FunctionType.fromFunctionName("element_at")); + assertEquals(FunctionType.MAP, FunctionType.fromFunctionName("try_element_at")); + assertEquals(FunctionType.DATE_TIMESTAMP, FunctionType.fromFunctionName("add_months")); + assertEquals(FunctionType.DATE_TIMESTAMP, FunctionType.fromFunctionName("year")); + assertEquals(FunctionType.JSON, FunctionType.fromFunctionName("from_json")); + assertEquals(FunctionType.JSON, FunctionType.fromFunctionName("to_json")); + assertEquals(FunctionType.MATH, FunctionType.fromFunctionName("abs")); + assertEquals(FunctionType.MATH, FunctionType.fromFunctionName("width_bucket")); + assertEquals(FunctionType.STRING, FunctionType.fromFunctionName("ascii")); + assertEquals(FunctionType.STRING, FunctionType.fromFunctionName("upper")); + assertEquals(FunctionType.CONDITIONAL, FunctionType.fromFunctionName("coalesce")); + assertEquals(FunctionType.CONDITIONAL, FunctionType.fromFunctionName("nvl2")); + assertEquals(FunctionType.BITWISE, FunctionType.fromFunctionName("bit_count")); + assertEquals(FunctionType.BITWISE, FunctionType.fromFunctionName("shiftrightunsigned")); + assertEquals(FunctionType.CONVERSION, FunctionType.fromFunctionName("bigint")); + assertEquals(FunctionType.CONVERSION, FunctionType.fromFunctionName("tinyint")); + assertEquals(FunctionType.PREDICATE, FunctionType.fromFunctionName("isnan")); + assertEquals(FunctionType.PREDICATE, FunctionType.fromFunctionName("rlike")); + assertEquals(FunctionType.CSV, FunctionType.fromFunctionName("from_csv")); + assertEquals(FunctionType.CSV, FunctionType.fromFunctionName("to_csv")); + assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("aes_decrypt")); + assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version")); + assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode")); + assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack")); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 85f9d0f284..88fe273aec 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -134,13 +134,13 @@ private enum TestQuery { DATE_AND_TIMESTAMP_FUNCTIONS("SELECT date_format(current_date(), 'yyyy-MM-dd');"), JSON_FUNCTIONS("SELECT json_tuple('{\"a\":1, \"b\":2}', 'a', 'b');"), MATHEMATICAL_FUNCTIONS("SELECT round(3.1415, 2);"), - STRING_FUNCTIONS("SELECT concat('Hello', ' ', 'World');"), - BITWISE_FUNCTIONS("SELECT bitwiseNOT(42);"), + STRING_FUNCTIONS("SELECT map_concat('Hello', ' ', 'World');"), + BITWISE_FUNCTIONS("SELECT bit_count(42);"), CONVERSION_FUNCTIONS("SELECT cast('2023-04-01' as date);"), CONDITIONAL_FUNCTIONS("SELECT if(1 > 0, 'true', 'false');"), - PREDICATE_FUNCTIONS("SELECT array_exists(array(1, 2, 3), x -> x > 2);"), - CSV_FUNCTIONS("SELECT csv_from_array(array('a', 'b', 'c'), ',');"), - MISC_FUNCTIONS("SELECT hash('Hello World');"), + PREDICATE_FUNCTIONS("SELECT isnotnull(1);"), + CSV_FUNCTIONS("SELECT from_csv(array('a', 'b', 'c'), ',');"), + MISC_FUNCTIONS("SELECT current_user();"), // Aggregate-like Functions AGGREGATE_FUNCTIONS("SELECT count(*), max(age), min(age) FROM my_table;"), @@ -251,30 +251,30 @@ void s3glueQueries() { verifyValid(v, TestQuery.UNCACHE_TABLE); // Functions - // verifyValid(v, TestQuery.ARRAY_FUNCTIONS); - // verifyValid(v, TestQuery.MAP_FUNCTIONS); - // verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); - // verifyValid(v, TestQuery.JSON_FUNCTIONS); - // verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); - // verifyValid(v, TestQuery.STRING_FUNCTIONS); - // verifyValid(v, TestQuery.BITWISE_FUNCTIONS); - // verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); - // verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); - // verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); - // verifyValid(v, TestQuery.CSV_FUNCTIONS); - // verifyValid(v, TestQuery.MISC_FUNCTIONS); + verifyValid(v, TestQuery.ARRAY_FUNCTIONS); + verifyValid(v, TestQuery.MAP_FUNCTIONS); + verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); + verifyValid(v, TestQuery.JSON_FUNCTIONS); + verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); + verifyValid(v, TestQuery.STRING_FUNCTIONS); + verifyValid(v, TestQuery.BITWISE_FUNCTIONS); + verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); + verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); + verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); + verifyValid(v, TestQuery.CSV_FUNCTIONS); + verifyInvalid(v, TestQuery.MISC_FUNCTIONS); // Aggregate-like Functions - // verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); - // verifyValid(v, TestQuery.WINDOW_FUNCTIONS); + verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); + verifyValid(v, TestQuery.WINDOW_FUNCTIONS); // Generator Functions - // verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); + verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); // UDFs - // verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); - // verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); - // verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); + verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); + verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } void verifyValid(SQLQueryValidator validator, TestQuery query) { From c9f960f631964fdd4aaa16f26071c347917ceab1 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Tue, 17 Sep 2024 16:02:06 -0700 Subject: [PATCH 03/14] fix style Signed-off-by: Tomoyuki Morita --- .../java/org/opensearch/sql/spark/validator/FunctionType.java | 2 +- .../sql/spark/validator/GrammarElementValidatorFactory.java | 3 +-- .../org/opensearch/sql/spark/validator/SQLQueryValidator.java | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java index 0a821a7a8c..a17f2f8b21 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java @@ -379,7 +379,7 @@ public enum FunctionType { "timestamp", "tinyint")) .put(PREDICATE, Set.of("isnan", "isnotnull", "isnull", "regexp", "regexp_like", "rlike")) - .put(CSV, Set.of("from_csv","schema_of_csv","to_csv")) + .put(CSV, Set.of("from_csv", "schema_of_csv", "to_csv")) .put( MISC, Set.of( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index 57374b852c..bd297c7afa 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -68,8 +68,7 @@ public class GrammarElementValidatorFactory { UNCACHE_TABLE, CSV_FUNCTIONS, MISC_FUNCTIONS, - UDF - ) + UDF) .build(); private static final Set S3GLUE_DENY_LIST = diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index fba973f930..6c84be2a57 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -49,7 +49,6 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SelectClauseContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; @@ -464,7 +463,7 @@ public Void visitFunctionName(FunctionNameContext ctx) { private void validateFunctionAllowed(String function) { FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); - switch(type) { + switch (type) { case MAP: validateAllowed(GrammarElement.MAP_FUNCTIONS); break; From 9665a627f49db8a78a3bf9d2dc002c248c9a603d Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Tue, 17 Sep 2024 17:53:12 -0700 Subject: [PATCH 04/14] Add security lake Signed-off-by: Tomoyuki Morita --- .../sql/spark/validator/GrammarElement.java | 1 + .../GrammarElementValidatorFactory.java | 60 ++- .../spark/validator/SQLQueryValidator.java | 128 +++++++ .../validator/SQLQueryValidatorTest.java | 349 ++++++++++++------ 4 files changed, 426 insertions(+), 112 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java index 05d878e9dc..3ee33d38fa 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -33,6 +33,7 @@ enum GrammarElement { HAVING("HAVING"), HINTS("HINTS"), INLINE_TABLE("Inline Table(VALUES)"), + FILE("File"), INNER_JOIN("INNER JOIN"), CROSS_JOIN("CROSS JOIN"), LEFT_OUTER_JOIN("LEFT OUTER JOIN"), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index bd297c7afa..71f989e456 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -17,6 +17,7 @@ public class GrammarElementValidatorFactory { private static final Set DEFAULT_DENY_LIST = ImmutableSet.of(CREATE_FUNCTION, DROP_FUNCTION, INSERT, LOAD, HINTS, TABLESAMPLE); + // Deny List for CloudWatch Logs datasource private static final Set CWL_DENY_LIST = copyBuilder(DEFAULT_DENY_LIST) .add( @@ -71,18 +72,53 @@ public class GrammarElementValidatorFactory { UDF) .build(); + // Deny list for S3 Glue datasource private static final Set S3GLUE_DENY_LIST = copyBuilder(DEFAULT_DENY_LIST) .add( ALTER_VIEW, CREATE_VIEW, DROP_VIEW, - REPAIR_TABLE, DISTRIBUTE_BY, INLINE_TABLE, + CLUSTER_BY, + DISTRIBUTE_BY, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + TRANSFORM, + MANAGE_RESOURCE, + DESCRIBE_FUNCTION, + REFRESH_RESOURCE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_FUNCTIONS, + SHOW_VIEWS, + MISC_FUNCTIONS, + UDF) + .build(); + + // Deny list for Security Lake datasource + private static final Set SL_DENY_LIST = + copyBuilder(DEFAULT_DENY_LIST) + .add( + ALTER_NAMESPACE, + ALTER_VIEW, + CREATE_NAMESPACE, + CREATE_VIEW, + DROP_NAMESPACE, + DROP_VIEW, + REPAIR_TABLE, TRUNCATE_TABLE, CLUSTER_BY, DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, CROSS_JOIN, LEFT_SEMI_JOIN, RIGHT_OUTER_JOIN, @@ -92,19 +128,39 @@ public class GrammarElementValidatorFactory { TABLE_VALUED_FUNCTION, TRANSFORM, MANAGE_RESOURCE, + ANALYZE_TABLE, + CACHE_TABLE, + CLEAR_CACHE, + DESCRIBE_NAMESPACE, DESCRIBE_FUNCTION, + DESCRIBE_QUERY, + DESCRIBE_TABLE, REFRESH_RESOURCE, + REFRESH_TABLE, REFRESH_FUNCTION, RESET, SET, + SHOW_COLUMNS, + SHOW_CREATE_TABLE, + SHOW_NAMESPACES, SHOW_FUNCTIONS, + SHOW_PARTITIONS, + SHOW_TABLE_EXTENDED, + SHOW_TABLES, + SHOW_TBLPROPERTIES, SHOW_VIEWS, + UNCACHE_TABLE, + CSV_FUNCTIONS, MISC_FUNCTIONS, UDF) .build(); + private static Map validatorMap = - ImmutableMap.of(DataSourceType.S3GLUE, new DenyListGrammarElementValidator(S3GLUE_DENY_LIST)); + ImmutableMap.of( + DataSourceType.S3GLUE, new DenyListGrammarElementValidator(S3GLUE_DENY_LIST), + DataSourceType.SECURITY_LAKE, new DenyListGrammarElementValidator(SL_DENY_LIST) + ); public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { return validatorMap.get(dataSourceType); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index 6c84be2a57..14d7b1ce22 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -8,6 +8,10 @@ import lombok.AllArgsConstructor; import org.antlr.v4.runtime.tree.TerminalNode; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTableColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTablePartitionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterClusterByContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterTableAlterColumnContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewQueryContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewSchemaBindingContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeContext; @@ -16,6 +20,8 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClearCacheContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClusterBySpecContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableLikeContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateViewContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CtesContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeFunctionContext; @@ -24,11 +30,15 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeRelationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropFunctionContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HiveReplaceColumnsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoTableContext; @@ -41,11 +51,16 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RecoverPartitionsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshFunctionContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshResourceContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableColumnContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTablePartitionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RepairTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ReplaceTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext; @@ -53,6 +68,8 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetQuantifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableLocationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableSerDeContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowCreateTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowFunctionsContext; @@ -64,6 +81,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; @@ -94,6 +112,80 @@ public Void visitUnsetNamespaceProperties(UnsetNamespacePropertiesContext ctx) { return super.visitUnsetNamespaceProperties(ctx); } + @Override + public Void visitAddTableColumns(AddTableColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAddTableColumns(ctx); + } + + @Override + public Void visitAddTablePartition(AddTablePartitionContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAddTablePartition(ctx); + } + + @Override + public Void visitRenameTableColumn(RenameTableColumnContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRenameTableColumn(ctx); + } + + @Override + public Void visitDropTableColumns(DropTableColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitDropTableColumns(ctx); + } + + @Override + public Void visitAlterTableAlterColumn(AlterTableAlterColumnContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAlterTableAlterColumn(ctx); + } + + @Override + public Void visitHiveReplaceColumns(HiveReplaceColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitHiveReplaceColumns(ctx); + } + + @Override + public Void visitSetTableSerDe(SetTableSerDeContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetTableSerDe(ctx); + } + + @Override + public Void visitRenameTablePartition(RenameTablePartitionContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRenameTablePartition(ctx); + } + + @Override + public Void visitDropTablePartitions(DropTablePartitionsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitDropTablePartitions(ctx); + } + + @Override + public Void visitSetTableLocation(SetTableLocationContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetTableLocation(ctx); + } + + @Override + public Void visitRecoverPartitions(RecoverPartitionsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRecoverPartitions(ctx); + } + + @Override + public Void visitAlterClusterBy(AlterClusterByContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAlterClusterBy(ctx); + } + + + @Override public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) { validateAllowed(GrammarElement.ALTER_NAMESPACE); @@ -131,12 +223,36 @@ public Void visitCreateNamespace(CreateNamespaceContext ctx) { return super.visitCreateNamespace(ctx); } + @Override + public Void visitCreateTable(CreateTableContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateTable(ctx); + } + + @Override + public Void visitCreateTableLike(CreateTableLikeContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateTableLike(ctx); + } + + @Override + public Void visitReplaceTable(ReplaceTableContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitReplaceTable(ctx); + } + @Override public Void visitDropNamespace(DropNamespaceContext ctx) { validateAllowed(GrammarElement.DROP_NAMESPACE); return super.visitDropNamespace(ctx); } + @Override + public Void visitDropTable(DropTableContext ctx) { + validateAllowed(GrammarElement.DROP_NAMESPACE); + return super.visitDropTable(ctx); + } + @Override public Void visitCreateView(CreateViewContext ctx) { validateAllowed(GrammarElement.CREATE_VIEW); @@ -155,6 +271,18 @@ public Void visitDropFunction(DropFunctionContext ctx) { return super.visitDropFunction(ctx); } + @Override + public Void visitRepairTable(RepairTableContext ctx) { + validateAllowed(GrammarElement.REPAIR_TABLE); + return super.visitRepairTable(ctx); + } + + @Override + public Void visitTruncateTable(TruncateTableContext ctx) { + validateAllowed(GrammarElement.TRUNCATE_TABLE); + return super.visitTruncateTable(ctx); + } + @Override public Void visitInsertOverwriteTable(InsertOverwriteTableContext ctx) { validateAllowed(GrammarElement.INSERT); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 88fe273aec..53a9a94a57 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -164,138 +164,267 @@ public String toString() { @Test void s3glueQueries() { - SQLQueryValidator v = - new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE)); - verifyValid(v, TestQuery.ALTER_DATABASE); - verifyValid(v, TestQuery.ALTER_TABLE); - verifyInvalid(v, TestQuery.ALTER_VIEW); - verifyValid(v, TestQuery.CREATE_DATABASE); - verifyInvalid(v, TestQuery.CREATE_FUNCTION); - verifyValid(v, TestQuery.CREATE_TABLE); - verifyInvalid(v, TestQuery.CREATE_VIEW); - verifyValid(v, TestQuery.DROP_DATABASE); - verifyInvalid(v, TestQuery.DROP_FUNCTION); - verifyValid(v, TestQuery.DROP_TABLE); - verifyInvalid(v, TestQuery.DROP_VIEW); - verifyValid(v, TestQuery.REPAIR_TABLE); - verifyValid(v, TestQuery.TRUNCATE_TABLE); + VerifyValidator v = new VerifyValidator(new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE))); + // DDL Statements + v.ok(TestQuery.ALTER_DATABASE); + v.ok(TestQuery.ALTER_TABLE); + v.ng(TestQuery.ALTER_VIEW); + v.ok(TestQuery.CREATE_DATABASE); + v.ng(TestQuery.CREATE_FUNCTION); + v.ok(TestQuery.CREATE_TABLE); + v.ng(TestQuery.CREATE_VIEW); + v.ok(TestQuery.DROP_DATABASE); + v.ng(TestQuery.DROP_FUNCTION); + v.ok(TestQuery.DROP_TABLE); + v.ng(TestQuery.DROP_VIEW); + v.ok(TestQuery.REPAIR_TABLE); + v.ok(TestQuery.TRUNCATE_TABLE); // DML Statements - verifyInvalid(v, TestQuery.INSERT_TABLE); - verifyInvalid(v, TestQuery.INSERT_OVERWRITE_DIRECTORY); - verifyInvalid(v, TestQuery.LOAD); + v.ng(TestQuery.INSERT_TABLE); + v.ng(TestQuery.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestQuery.LOAD); // Data Retrieval - verifyValid(v, TestQuery.SELECT); - verifyValid(v, TestQuery.EXPLAIN); - verifyValid(v, TestQuery.COMMON_TABLE_EXPRESSION); - verifyInvalid(v, TestQuery.CLUSTER_BY_CLAUSE); - verifyInvalid(v, TestQuery.DISTRIBUTE_BY_CLAUSE); - verifyValid(v, TestQuery.GROUP_BY_CLAUSE); - verifyValid(v, TestQuery.HAVING_CLAUSE); - verifyInvalid(v, TestQuery.HINTS); - verifyInvalid(v, TestQuery.INLINE_TABLE); - // verifyInvalid(v, TestQuery.FILE); TODO: need dive deep - verifyValid(v, TestQuery.INNER_JOIN); - verifyInvalid(v, TestQuery.CROSS_JOIN); - verifyValid(v, TestQuery.LEFT_OUTER_JOIN); - verifyInvalid(v, TestQuery.LEFT_SEMI_JOIN); - verifyInvalid(v, TestQuery.RIGHT_OUTER_JOIN); - verifyInvalid(v, TestQuery.FULL_OUTER_JOIN); - verifyInvalid(v, TestQuery.LEFT_ANTI_JOIN); - verifyValid(v, TestQuery.LIKE_PREDICATE); - verifyValid(v, TestQuery.LIMIT_CLAUSE); - verifyValid(v, TestQuery.OFFSET_CLAUSE); - verifyValid(v, TestQuery.ORDER_BY_CLAUSE); - verifyValid(v, TestQuery.SET_OPERATORS); - verifyValid(v, TestQuery.SORT_BY_CLAUSE); - verifyInvalid(v, TestQuery.TABLESAMPLE); - verifyInvalid(v, TestQuery.TABLE_VALUED_FUNCTION); - verifyValid(v, TestQuery.WHERE_CLAUSE); - verifyValid(v, TestQuery.AGGREGATE_FUNCTION); - verifyValid(v, TestQuery.WINDOW_FUNCTION); - verifyValid(v, TestQuery.CASE_CLAUSE); - verifyValid(v, TestQuery.PIVOT_CLAUSE); - verifyValid(v, TestQuery.UNPIVOT_CLAUSE); - verifyValid(v, TestQuery.LATERAL_VIEW_CLAUSE); - verifyValid(v, TestQuery.LATERAL_SUBQUERY); - verifyInvalid(v, TestQuery.TRANSFORM_CLAUSE); + v.ok(TestQuery.SELECT); + v.ok(TestQuery.EXPLAIN); + v.ok(TestQuery.COMMON_TABLE_EXPRESSION); + v.ng(TestQuery.CLUSTER_BY_CLAUSE); + v.ng(TestQuery.DISTRIBUTE_BY_CLAUSE); + v.ok(TestQuery.GROUP_BY_CLAUSE); + v.ok(TestQuery.HAVING_CLAUSE); + v.ng(TestQuery.HINTS); + v.ng(TestQuery.INLINE_TABLE); + // v.ng(TestQuery.FILE); TODO: need dive deep + v.ok(TestQuery.INNER_JOIN); + v.ng(TestQuery.CROSS_JOIN); + v.ok(TestQuery.LEFT_OUTER_JOIN); + v.ng(TestQuery.LEFT_SEMI_JOIN); + v.ng(TestQuery.RIGHT_OUTER_JOIN); + v.ng(TestQuery.FULL_OUTER_JOIN); + v.ng(TestQuery.LEFT_ANTI_JOIN); + v.ok(TestQuery.LIKE_PREDICATE); + v.ok(TestQuery.LIMIT_CLAUSE); + v.ok(TestQuery.OFFSET_CLAUSE); + v.ok(TestQuery.ORDER_BY_CLAUSE); + v.ok(TestQuery.SET_OPERATORS); + v.ok(TestQuery.SORT_BY_CLAUSE); + v.ng(TestQuery.TABLESAMPLE); + v.ng(TestQuery.TABLE_VALUED_FUNCTION); + v.ok(TestQuery.WHERE_CLAUSE); + v.ok(TestQuery.AGGREGATE_FUNCTION); + v.ok(TestQuery.WINDOW_FUNCTION); + v.ok(TestQuery.CASE_CLAUSE); + v.ok(TestQuery.PIVOT_CLAUSE); + v.ok(TestQuery.UNPIVOT_CLAUSE); + v.ok(TestQuery.LATERAL_VIEW_CLAUSE); + v.ok(TestQuery.LATERAL_SUBQUERY); + v.ng(TestQuery.TRANSFORM_CLAUSE); // Auxiliary Statements - verifyInvalid(v, TestQuery.ADD_FILE); - verifyInvalid(v, TestQuery.ADD_JAR); - verifyValid(v, TestQuery.ANALYZE_TABLE); - verifyValid(v, TestQuery.CACHE_TABLE); - verifyValid(v, TestQuery.CLEAR_CACHE); - verifyValid(v, TestQuery.DESCRIBE_DATABASE); - verifyInvalid(v, TestQuery.DESCRIBE_FUNCTION); - verifyValid(v, TestQuery.DESCRIBE_QUERY); - verifyValid(v, TestQuery.DESCRIBE_TABLE); - verifyInvalid(v, TestQuery.LIST_FILE); - verifyInvalid(v, TestQuery.LIST_JAR); - verifyInvalid(v, TestQuery.REFRESH); - // verifyValid(v, TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to + v.ng(TestQuery.ADD_FILE); + v.ng(TestQuery.ADD_JAR); + v.ok(TestQuery.ANALYZE_TABLE); + v.ok(TestQuery.CACHE_TABLE); + v.ok(TestQuery.CLEAR_CACHE); + v.ok(TestQuery.DESCRIBE_DATABASE); + v.ng(TestQuery.DESCRIBE_FUNCTION); + v.ok(TestQuery.DESCRIBE_QUERY); + v.ok(TestQuery.DESCRIBE_TABLE); + v.ng(TestQuery.LIST_FILE); + v.ng(TestQuery.LIST_JAR); + v.ng(TestQuery.REFRESH); + // v.ok(TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to // refreshResource) - verifyInvalid(v, TestQuery.REFRESH_FUNCTION); - verifyInvalid(v, TestQuery.RESET); - verifyInvalid(v, TestQuery.SET); - verifyValid(v, TestQuery.SHOW_COLUMNS); - verifyValid(v, TestQuery.SHOW_CREATE_TABLE); - verifyValid(v, TestQuery.SHOW_DATABASES); - verifyInvalid(v, TestQuery.SHOW_FUNCTIONS); - verifyValid(v, TestQuery.SHOW_PARTITIONS); - verifyValid(v, TestQuery.SHOW_TABLE_EXTENDED); - verifyValid(v, TestQuery.SHOW_TABLES); - verifyValid(v, TestQuery.SHOW_TBLPROPERTIES); - verifyInvalid(v, TestQuery.SHOW_VIEWS); - verifyValid(v, TestQuery.UNCACHE_TABLE); + v.ng(TestQuery.REFRESH_FUNCTION); + v.ng(TestQuery.RESET); + v.ng(TestQuery.SET); + v.ok(TestQuery.SHOW_COLUMNS); + v.ok(TestQuery.SHOW_CREATE_TABLE); + v.ok(TestQuery.SHOW_DATABASES); + v.ng(TestQuery.SHOW_FUNCTIONS); + v.ok(TestQuery.SHOW_PARTITIONS); + v.ok(TestQuery.SHOW_TABLE_EXTENDED); + v.ok(TestQuery.SHOW_TABLES); + v.ok(TestQuery.SHOW_TBLPROPERTIES); + v.ng(TestQuery.SHOW_VIEWS); + v.ok(TestQuery.UNCACHE_TABLE); // Functions - verifyValid(v, TestQuery.ARRAY_FUNCTIONS); - verifyValid(v, TestQuery.MAP_FUNCTIONS); - verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); - verifyValid(v, TestQuery.JSON_FUNCTIONS); - verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); - verifyValid(v, TestQuery.STRING_FUNCTIONS); - verifyValid(v, TestQuery.BITWISE_FUNCTIONS); - verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); - verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); - verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); - verifyValid(v, TestQuery.CSV_FUNCTIONS); - verifyInvalid(v, TestQuery.MISC_FUNCTIONS); + v.ok(TestQuery.ARRAY_FUNCTIONS); + v.ok(TestQuery.MAP_FUNCTIONS); + v.ok(TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestQuery.JSON_FUNCTIONS); + v.ok(TestQuery.MATHEMATICAL_FUNCTIONS); + v.ok(TestQuery.STRING_FUNCTIONS); + v.ok(TestQuery.BITWISE_FUNCTIONS); + v.ok(TestQuery.CONVERSION_FUNCTIONS); + v.ok(TestQuery.CONDITIONAL_FUNCTIONS); + v.ok(TestQuery.PREDICATE_FUNCTIONS); + v.ok(TestQuery.CSV_FUNCTIONS); + v.ng(TestQuery.MISC_FUNCTIONS); // Aggregate-like Functions - verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); - verifyValid(v, TestQuery.WINDOW_FUNCTIONS); + v.ok(TestQuery.AGGREGATE_FUNCTIONS); + v.ok(TestQuery.WINDOW_FUNCTIONS); // Generator Functions - verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); + v.ok(TestQuery.GENERATOR_FUNCTIONS); // UDFs - verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); - verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); - verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + v.ng(TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } - void verifyValid(SQLQueryValidator validator, TestQuery query) { - runValidate(validator, query.toString()); + @Test + void securityLakeQueries() { + VerifyValidator v = new VerifyValidator(new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.SECURITY_LAKE))); + // DDL Statements + v.ng(TestQuery.ALTER_DATABASE); + v.ng(TestQuery.ALTER_TABLE); + v.ng(TestQuery.ALTER_VIEW); + v.ng(TestQuery.CREATE_DATABASE); + v.ng(TestQuery.CREATE_FUNCTION); + v.ng(TestQuery.CREATE_TABLE); + v.ng(TestQuery.CREATE_VIEW); + v.ng(TestQuery.DROP_DATABASE); + v.ng(TestQuery.DROP_FUNCTION); + v.ng(TestQuery.DROP_TABLE); + v.ng(TestQuery.DROP_VIEW); + v.ng(TestQuery.REPAIR_TABLE); + v.ng(TestQuery.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestQuery.INSERT_TABLE); + v.ng(TestQuery.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestQuery.LOAD); + + // Data Retrieval + v.ok(TestQuery.SELECT); + v.ok(TestQuery.EXPLAIN); + v.ok(TestQuery.COMMON_TABLE_EXPRESSION); + v.ng(TestQuery.CLUSTER_BY_CLAUSE); + v.ng(TestQuery.DISTRIBUTE_BY_CLAUSE); + v.ok(TestQuery.GROUP_BY_CLAUSE); + v.ok(TestQuery.HAVING_CLAUSE); + v.ng(TestQuery.HINTS); + v.ng(TestQuery.INLINE_TABLE); + // v.ng(TestQuery.FILE); TODO: need dive deep + v.ok(TestQuery.INNER_JOIN); + v.ng(TestQuery.CROSS_JOIN); + v.ok(TestQuery.LEFT_OUTER_JOIN); + v.ng(TestQuery.LEFT_SEMI_JOIN); + v.ng(TestQuery.RIGHT_OUTER_JOIN); + v.ng(TestQuery.FULL_OUTER_JOIN); + v.ng(TestQuery.LEFT_ANTI_JOIN); + v.ok(TestQuery.LIKE_PREDICATE); + v.ok(TestQuery.LIMIT_CLAUSE); + v.ok(TestQuery.OFFSET_CLAUSE); + v.ok(TestQuery.ORDER_BY_CLAUSE); + v.ok(TestQuery.SET_OPERATORS); + v.ok(TestQuery.SORT_BY_CLAUSE); + v.ng(TestQuery.TABLESAMPLE); + v.ng(TestQuery.TABLE_VALUED_FUNCTION); + v.ok(TestQuery.WHERE_CLAUSE); + v.ok(TestQuery.AGGREGATE_FUNCTION); + v.ok(TestQuery.WINDOW_FUNCTION); + v.ok(TestQuery.CASE_CLAUSE); + v.ok(TestQuery.PIVOT_CLAUSE); + v.ok(TestQuery.UNPIVOT_CLAUSE); + v.ok(TestQuery.LATERAL_VIEW_CLAUSE); + v.ok(TestQuery.LATERAL_SUBQUERY); + v.ng(TestQuery.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestQuery.ADD_FILE); + v.ng(TestQuery.ADD_JAR); + v.ng(TestQuery.ANALYZE_TABLE); + v.ng(TestQuery.CACHE_TABLE); + v.ng(TestQuery.CLEAR_CACHE); + v.ng(TestQuery.DESCRIBE_DATABASE); + v.ng(TestQuery.DESCRIBE_FUNCTION); + v.ng(TestQuery.DESCRIBE_QUERY); + v.ng(TestQuery.DESCRIBE_TABLE); + v.ng(TestQuery.LIST_FILE); + v.ng(TestQuery.LIST_JAR); + v.ng(TestQuery.REFRESH); + // v.ng(TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to + // refreshResource) + v.ng(TestQuery.REFRESH_FUNCTION); + v.ng(TestQuery.RESET); + v.ng(TestQuery.SET); + v.ng(TestQuery.SHOW_COLUMNS); + v.ng(TestQuery.SHOW_CREATE_TABLE); + v.ng(TestQuery.SHOW_DATABASES); + v.ng(TestQuery.SHOW_FUNCTIONS); + v.ng(TestQuery.SHOW_PARTITIONS); + v.ng(TestQuery.SHOW_TABLE_EXTENDED); + v.ng(TestQuery.SHOW_TABLES); + v.ng(TestQuery.SHOW_TBLPROPERTIES); + v.ng(TestQuery.SHOW_VIEWS); + v.ng(TestQuery.UNCACHE_TABLE); + + // Functions + v.ok(TestQuery.ARRAY_FUNCTIONS); + v.ok(TestQuery.MAP_FUNCTIONS); + v.ok(TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestQuery.JSON_FUNCTIONS); + v.ok(TestQuery.MATHEMATICAL_FUNCTIONS); + v.ok(TestQuery.STRING_FUNCTIONS); + v.ok(TestQuery.BITWISE_FUNCTIONS); + v.ok(TestQuery.CONVERSION_FUNCTIONS); + v.ok(TestQuery.CONDITIONAL_FUNCTIONS); + v.ok(TestQuery.PREDICATE_FUNCTIONS); + v.ng(TestQuery.CSV_FUNCTIONS); + v.ng(TestQuery.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestQuery.AGGREGATE_FUNCTIONS); + v.ok(TestQuery.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestQuery.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } - void verifyInvalid(SQLQueryValidator validator, TestQuery query) { - assertThrows( - IllegalArgumentException.class, - () -> runValidate(validator, query.toString()), - "The query should throw: query=`" + query.toString() + "`"); + @AllArgsConstructor + private static class VerifyValidator { + private final SQLQueryValidator validator; + + public void ok(TestQuery query) { + runValidate(validator, query.toString()); + + } + + public void ng(TestQuery query) { + assertThrows( + IllegalArgumentException.class, + () -> runValidate(validator, query.toString()), + "The query should throw: query=`" + query.toString() + "`"); + } + + void runValidate(SQLQueryValidator validator, String query) { + validator.validate(getParser(query)); + } + + SingleStatementContext getParser(String query) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(query)))); + return sqlBaseParser.singleStatement(); + } } - void runValidate(SQLQueryValidator validator, String query) { - validator.validate(getParser(query)); + void ok(SQLQueryValidator validator, TestQuery query) { } - SingleStatementContext getParser(String query) { - SqlBaseParser sqlBaseParser = - new SqlBaseParser( - new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(query)))); - return sqlBaseParser.singleStatement(); + void ng(SQLQueryValidator validator, TestQuery query) { } + + } From 9fafcb9d206a40794e18fd25de86553bb927cb66 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Wed, 18 Sep 2024 12:28:00 -0700 Subject: [PATCH 05/14] Add File support Signed-off-by: Tomoyuki Morita --- .../GrammarElementValidatorFactory.java | 32 ++++++++++++------- .../spark/validator/SQLQueryValidator.java | 19 +++++++++-- .../sql/spark/validator/FunctionTypeTest.java | 1 + .../validator/SQLQueryValidatorTest.java | 21 +++++------- 4 files changed, 47 insertions(+), 26 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index 71f989e456..ba2672e3df 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -14,27 +14,29 @@ import org.opensearch.sql.datasource.model.DataSourceType; public class GrammarElementValidatorFactory { - private static final Set DEFAULT_DENY_LIST = - ImmutableSet.of(CREATE_FUNCTION, DROP_FUNCTION, INSERT, LOAD, HINTS, TABLESAMPLE); - // Deny List for CloudWatch Logs datasource private static final Set CWL_DENY_LIST = - copyBuilder(DEFAULT_DENY_LIST) + ImmutableSet.builder() .add( ALTER_NAMESPACE, ALTER_VIEW, CREATE_NAMESPACE, + CREATE_FUNCTION, CREATE_VIEW, + DROP_FUNCTION, DROP_NAMESPACE, DROP_VIEW, REPAIR_TABLE, TRUNCATE_TABLE, + INSERT, + LOAD, EXPLAIN, WITH, CLUSTER_BY, DISTRIBUTE_BY, HINTS, INLINE_TABLE, + FILE, CROSS_JOIN, LEFT_SEMI_JOIN, RIGHT_OUTER_JOIN, @@ -74,15 +76,20 @@ public class GrammarElementValidatorFactory { // Deny list for S3 Glue datasource private static final Set S3GLUE_DENY_LIST = - copyBuilder(DEFAULT_DENY_LIST) + ImmutableSet.builder() .add( ALTER_VIEW, + CREATE_FUNCTION, CREATE_VIEW, + DROP_FUNCTION, DROP_VIEW, - DISTRIBUTE_BY, - INLINE_TABLE, + INSERT, + LOAD, CLUSTER_BY, DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + FILE, CROSS_JOIN, LEFT_SEMI_JOIN, RIGHT_OUTER_JOIN, @@ -105,20 +112,25 @@ public class GrammarElementValidatorFactory { // Deny list for Security Lake datasource private static final Set SL_DENY_LIST = - copyBuilder(DEFAULT_DENY_LIST) + ImmutableSet.builder() .add( ALTER_NAMESPACE, ALTER_VIEW, CREATE_NAMESPACE, + CREATE_FUNCTION, CREATE_VIEW, + DROP_FUNCTION, DROP_NAMESPACE, DROP_VIEW, REPAIR_TABLE, TRUNCATE_TABLE, + INSERT, + LOAD, CLUSTER_BY, DISTRIBUTE_BY, HINTS, INLINE_TABLE, + FILE, CROSS_JOIN, LEFT_SEMI_JOIN, RIGHT_OUTER_JOIN, @@ -155,12 +167,10 @@ public class GrammarElementValidatorFactory { UDF) .build(); - private static Map validatorMap = ImmutableMap.of( DataSourceType.S3GLUE, new DenyListGrammarElementValidator(S3GLUE_DENY_LIST), - DataSourceType.SECURITY_LAKE, new DenyListGrammarElementValidator(SL_DENY_LIST) - ); + DataSourceType.SECURITY_LAKE, new DenyListGrammarElementValidator(SL_DENY_LIST)); public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { return validatorMap.get(dataSourceType); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index 14d7b1ce22..2fd91bf87a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -79,6 +79,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTablesContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTblPropertiesContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableNameContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext; @@ -184,8 +185,6 @@ public Void visitAlterClusterBy(AlterClusterByContext ctx) { return super.visitAlterClusterBy(ctx); } - - @Override public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) { validateAllowed(GrammarElement.ALTER_NAMESPACE); @@ -325,6 +324,22 @@ public Void visitExplain(ExplainContext ctx) { return super.visitExplain(ctx); } + @Override + public Void visitTableName(TableNameContext ctx) { + String reference = ctx.identifierReference().getText(); + System.out.println(reference); + if (isFileReference(reference)) { + validateAllowed(GrammarElement.FILE); + } + return super.visitTableName(ctx); + } + + private static final String FILE_REFERENCE_PATTERN = "^[a-zA-Z]+\\.`[^`]+`$"; + + private boolean isFileReference(String reference) { + return reference.matches(FILE_REFERENCE_PATTERN); + } + @Override public Void visitCtes(CtesContext ctx) { validateAllowed(GrammarElement.WITH); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java index 920d35df2f..a5f868421c 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java @@ -42,5 +42,6 @@ public void test() { assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version")); assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode")); assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack")); + assertEquals(FunctionType.UDF, FunctionType.fromFunctionName("unknown")); } } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 53a9a94a57..747485e718 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -164,7 +164,9 @@ public String toString() { @Test void s3glueQueries() { - VerifyValidator v = new VerifyValidator(new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE))); + VerifyValidator v = + new VerifyValidator( + new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE))); // DDL Statements v.ok(TestQuery.ALTER_DATABASE); v.ok(TestQuery.ALTER_TABLE); @@ -195,7 +197,7 @@ void s3glueQueries() { v.ok(TestQuery.HAVING_CLAUSE); v.ng(TestQuery.HINTS); v.ng(TestQuery.INLINE_TABLE); - // v.ng(TestQuery.FILE); TODO: need dive deep + v.ng(TestQuery.FILE); v.ok(TestQuery.INNER_JOIN); v.ng(TestQuery.CROSS_JOIN); v.ok(TestQuery.LEFT_OUTER_JOIN); @@ -279,7 +281,9 @@ void s3glueQueries() { @Test void securityLakeQueries() { - VerifyValidator v = new VerifyValidator(new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.SECURITY_LAKE))); + VerifyValidator v = + new VerifyValidator( + new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.SECURITY_LAKE))); // DDL Statements v.ng(TestQuery.ALTER_DATABASE); v.ng(TestQuery.ALTER_TABLE); @@ -310,7 +314,7 @@ void securityLakeQueries() { v.ok(TestQuery.HAVING_CLAUSE); v.ng(TestQuery.HINTS); v.ng(TestQuery.INLINE_TABLE); - // v.ng(TestQuery.FILE); TODO: need dive deep + v.ng(TestQuery.FILE); v.ok(TestQuery.INNER_JOIN); v.ng(TestQuery.CROSS_JOIN); v.ok(TestQuery.LEFT_OUTER_JOIN); @@ -398,7 +402,6 @@ private static class VerifyValidator { public void ok(TestQuery query) { runValidate(validator, query.toString()); - } public void ng(TestQuery query) { @@ -419,12 +422,4 @@ SingleStatementContext getParser(String query) { return sqlBaseParser.singleStatement(); } } - - void ok(SQLQueryValidator validator, TestQuery query) { - } - - void ng(SQLQueryValidator validator, TestQuery query) { - } - - } From 424be1266a1151675751176542c49f63afdf7931 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Wed, 18 Sep 2024 14:23:19 -0700 Subject: [PATCH 06/14] Integrate into SparkQueryDispatcher Signed-off-by: Tomoyuki Morita --- .../dispatcher/SparkQueryDispatcher.java | 11 +- .../sql/spark/utils/SQLQueryUtils.java | 8 + ...CloudWatchLogsGrammarElementValidator.java | 76 +++ .../sql/spark/validator/GrammarElement.java | 10 +- .../GrammarElementValidatorFactory.java | 167 +---- .../S3GlueGrammarElementValidator.java | 81 +++ .../validator/SQLQueryValidationVisitor.java | 625 +++++++++++++++++ .../spark/validator/SQLQueryValidator.java | 628 +----------------- .../SecurityLakeGrammarElementValidator.java | 123 ++++ .../asyncquery/AsyncQueryCoreIntegTest.java | 10 +- .../dispatcher/SparkQueryDispatcherTest.java | 24 +- .../validator/SQLQueryValidatorTest.java | 29 +- .../config/AsyncExecutorServiceModule.java | 6 +- .../AsyncQueryExecutorServiceSpec.java | 7 +- 14 files changed, 988 insertions(+), 817 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 732f5f71ab..ff8c8d1fe8 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -6,7 +6,6 @@ package org.opensearch.sql.spark.dispatcher; import java.util.HashMap; -import java.util.List; import java.util.Map; import lombok.AllArgsConstructor; import org.jetbrains.annotations.NotNull; @@ -24,6 +23,7 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.SQLQueryUtils; +import org.opensearch.sql.spark.validator.SQLQueryValidator; /** This class takes care of understanding query and dispatching job query to emr serverless. */ @AllArgsConstructor @@ -38,6 +38,7 @@ public class SparkQueryDispatcher { private final SessionManager sessionManager; private final QueryHandlerFactory queryHandlerFactory; private final QueryIdProvider queryIdProvider; + private final SQLQueryValidator sqlQueryValidator; public DispatchQueryResponse dispatch( DispatchQueryRequest dispatchQueryRequest, @@ -54,13 +55,7 @@ public DispatchQueryResponse dispatch( dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); } - List validationErrors = - SQLQueryUtils.validateSparkSqlQuery( - dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query); - if (!validationErrors.isEmpty()) { - throw new IllegalArgumentException( - "Query is not allowed: " + String.join(", ", validationErrors)); - } + sqlQueryValidator.validate(query, dataSourceMetadata.getConnector()); } return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 7550de2f1e..92717acd9c 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -103,6 +103,14 @@ public static List validateSparkSqlQuery(DataSource datasource, String s } } + public static SqlBaseParser getBaseParser(String sqlQuery) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); + sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); + return sqlBaseParser; + } + private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) { if (datasource != null && datasource.getConnectorType() != null diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java new file mode 100644 index 0000000000..6a78601191 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.opensearch.sql.spark.validator.GrammarElement.*; + +import com.google.common.collect.ImmutableSet; +import java.util.Set; + +public class CloudWatchLogsGrammarElementValidator extends DenyListGrammarElementValidator { + private static final Set CWL_DENY_LIST = + ImmutableSet.builder() + .add( + ALTER_NAMESPACE, + ALTER_VIEW, + CREATE_NAMESPACE, + CREATE_FUNCTION, + CREATE_VIEW, + DROP_FUNCTION, + DROP_NAMESPACE, + DROP_VIEW, + REPAIR_TABLE, + TRUNCATE_TABLE, + INSERT, + LOAD, + EXPLAIN, + WITH, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + FILE, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + LATERAL_VIEW, + LATERAL_SUBQUERY, + TRANSFORM, + MANAGE_RESOURCE, + ANALYZE_TABLE, + CACHE_TABLE, + DESCRIBE_NAMESPACE, + DESCRIBE_FUNCTION, + DESCRIBE_QUERY, + DESCRIBE_TABLE, + REFRESH_RESOURCE, + REFRESH_TABLE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_COLUMNS, + SHOW_CREATE_TABLE, + SHOW_NAMESPACES, + SHOW_FUNCTIONS, + SHOW_PARTITIONS, + SHOW_TABLE_EXTENDED, + SHOW_TABLES, + SHOW_TBLPROPERTIES, + SHOW_VIEWS, + UNCACHE_TABLE, + CSV_FUNCTIONS, + MISC_FUNCTIONS, + UDF) + .build(); + + public CloudWatchLogsGrammarElementValidator() { + super(CWL_DENY_LIST); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java index 3ee33d38fa..d14387a7ac 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -9,16 +9,16 @@ @AllArgsConstructor enum GrammarElement { - ALTER_NAMESPACE("ALTER DATABASE/TABLE/NAMESPACE"), + ALTER_NAMESPACE("ALTER (DATABASE|TABLE|NAMESPACE)"), ALTER_VIEW("ALTER VIEW"), - CREATE_NAMESPACE("CREATE DATABASE/TABLE/NAMESPACE"), + CREATE_NAMESPACE("CREATE (DATABASE|TABLE|NAMESPACE)"), CREATE_FUNCTION("CREATE FUNCTION"), CREATE_VIEW("CREATE VIEW"), - DROP_NAMESPACE("DROP DATABASE/TABLE/NAMESPACE"), + DROP_NAMESPACE("DROP (DATABASE|TABLE|NAMESPACE)"), DROP_FUNCTION("DROP FUNCTION"), DROP_VIEW("DROP VIEW"), DROP_TABLE("DROP TABLE"), - REPAIR_TABLE("REPAIR TABLE"), // does this conflict with DROP_NAMESPACE? + REPAIR_TABLE("REPAIR TABLE"), TRUNCATE_TABLE("TRUNCATE TABLE"), // DML Statements INSERT("INSERT"), @@ -52,7 +52,7 @@ enum GrammarElement { ANALYZE_TABLE("ANALYZE TABLE(S)"), CACHE_TABLE("CACHE TABLE"), CLEAR_CACHE("CLEAR CACHE"), - DESCRIBE_NAMESPACE("DESCRIBE (NAMESPACE|DATABASE|SCHEMA"), + DESCRIBE_NAMESPACE("DESCRIBE (NAMESPACE|DATABASE|SCHEMA)"), DESCRIBE_FUNCTION("DESCRIBE FUNCTION"), DESCRIBE_QUERY("DESCRIBE QUERY"), DESCRIBE_TABLE("DESCRIBE TABLE"), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index ba2672e3df..68c230fc71 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -5,178 +5,21 @@ package org.opensearch.sql.spark.validator; -import static org.opensearch.sql.spark.validator.GrammarElement.*; - import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.Map; -import java.util.Set; import org.opensearch.sql.datasource.model.DataSourceType; public class GrammarElementValidatorFactory { - // Deny List for CloudWatch Logs datasource - private static final Set CWL_DENY_LIST = - ImmutableSet.builder() - .add( - ALTER_NAMESPACE, - ALTER_VIEW, - CREATE_NAMESPACE, - CREATE_FUNCTION, - CREATE_VIEW, - DROP_FUNCTION, - DROP_NAMESPACE, - DROP_VIEW, - REPAIR_TABLE, - TRUNCATE_TABLE, - INSERT, - LOAD, - EXPLAIN, - WITH, - CLUSTER_BY, - DISTRIBUTE_BY, - HINTS, - INLINE_TABLE, - FILE, - CROSS_JOIN, - LEFT_SEMI_JOIN, - RIGHT_OUTER_JOIN, - FULL_OUTER_JOIN, - LEFT_ANTI_JOIN, - TABLESAMPLE, - TABLE_VALUED_FUNCTION, - LATERAL_VIEW, - LATERAL_SUBQUERY, - TRANSFORM, - MANAGE_RESOURCE, - ANALYZE_TABLE, - CACHE_TABLE, - DESCRIBE_NAMESPACE, - DESCRIBE_FUNCTION, - DESCRIBE_QUERY, - DESCRIBE_TABLE, - REFRESH_RESOURCE, - REFRESH_TABLE, - REFRESH_FUNCTION, - RESET, - SET, - SHOW_COLUMNS, - SHOW_CREATE_TABLE, - SHOW_NAMESPACES, - SHOW_FUNCTIONS, - SHOW_PARTITIONS, - SHOW_TABLE_EXTENDED, - SHOW_TABLES, - SHOW_TBLPROPERTIES, - SHOW_VIEWS, - UNCACHE_TABLE, - CSV_FUNCTIONS, - MISC_FUNCTIONS, - UDF) - .build(); - - // Deny list for S3 Glue datasource - private static final Set S3GLUE_DENY_LIST = - ImmutableSet.builder() - .add( - ALTER_VIEW, - CREATE_FUNCTION, - CREATE_VIEW, - DROP_FUNCTION, - DROP_VIEW, - INSERT, - LOAD, - CLUSTER_BY, - DISTRIBUTE_BY, - HINTS, - INLINE_TABLE, - FILE, - CROSS_JOIN, - LEFT_SEMI_JOIN, - RIGHT_OUTER_JOIN, - FULL_OUTER_JOIN, - LEFT_ANTI_JOIN, - TABLESAMPLE, - TABLE_VALUED_FUNCTION, - TRANSFORM, - MANAGE_RESOURCE, - DESCRIBE_FUNCTION, - REFRESH_RESOURCE, - REFRESH_FUNCTION, - RESET, - SET, - SHOW_FUNCTIONS, - SHOW_VIEWS, - MISC_FUNCTIONS, - UDF) - .build(); - - // Deny list for Security Lake datasource - private static final Set SL_DENY_LIST = - ImmutableSet.builder() - .add( - ALTER_NAMESPACE, - ALTER_VIEW, - CREATE_NAMESPACE, - CREATE_FUNCTION, - CREATE_VIEW, - DROP_FUNCTION, - DROP_NAMESPACE, - DROP_VIEW, - REPAIR_TABLE, - TRUNCATE_TABLE, - INSERT, - LOAD, - CLUSTER_BY, - DISTRIBUTE_BY, - HINTS, - INLINE_TABLE, - FILE, - CROSS_JOIN, - LEFT_SEMI_JOIN, - RIGHT_OUTER_JOIN, - FULL_OUTER_JOIN, - LEFT_ANTI_JOIN, - TABLESAMPLE, - TABLE_VALUED_FUNCTION, - TRANSFORM, - MANAGE_RESOURCE, - ANALYZE_TABLE, - CACHE_TABLE, - CLEAR_CACHE, - DESCRIBE_NAMESPACE, - DESCRIBE_FUNCTION, - DESCRIBE_QUERY, - DESCRIBE_TABLE, - REFRESH_RESOURCE, - REFRESH_TABLE, - REFRESH_FUNCTION, - RESET, - SET, - SHOW_COLUMNS, - SHOW_CREATE_TABLE, - SHOW_NAMESPACES, - SHOW_FUNCTIONS, - SHOW_PARTITIONS, - SHOW_TABLE_EXTENDED, - SHOW_TABLES, - SHOW_TBLPROPERTIES, - SHOW_VIEWS, - UNCACHE_TABLE, - CSV_FUNCTIONS, - MISC_FUNCTIONS, - UDF) - .build(); + private static GrammarElementValidator defaultValidator = new DenyListGrammarElementValidator( + ImmutableSet.of()); private static Map validatorMap = ImmutableMap.of( - DataSourceType.S3GLUE, new DenyListGrammarElementValidator(S3GLUE_DENY_LIST), - DataSourceType.SECURITY_LAKE, new DenyListGrammarElementValidator(SL_DENY_LIST)); + DataSourceType.S3GLUE, new S3GlueGrammarElementValidator(), + DataSourceType.SECURITY_LAKE, new SecurityLakeGrammarElementValidator()); public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { - return validatorMap.get(dataSourceType); - } - - private static ImmutableSet.Builder copyBuilder(Set original) { - return ImmutableSet.builder().addAll(original); + return validatorMap.getOrDefault(dataSourceType, defaultValidator); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java new file mode 100644 index 0000000000..9ed1fd9e9e --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.opensearch.sql.spark.validator.GrammarElement.ALTER_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.CLUSTER_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.CROSS_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DISTRIBUTE_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.FILE; +import static org.opensearch.sql.spark.validator.GrammarElement.FULL_OUTER_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.HINTS; +import static org.opensearch.sql.spark.validator.GrammarElement.INLINE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.INSERT; +import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_ANTI_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_SEMI_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.LOAD; +import static org.opensearch.sql.spark.validator.GrammarElement.MANAGE_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.MISC_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.RESET; +import static org.opensearch.sql.spark.validator.GrammarElement.RIGHT_OUTER_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.SET; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_VIEWS; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLESAMPLE; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLE_VALUED_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.TRANSFORM; +import static org.opensearch.sql.spark.validator.GrammarElement.UDF; + +import com.google.common.collect.ImmutableSet; +import java.util.Set; + +public class S3GlueGrammarElementValidator extends DenyListGrammarElementValidator { + private static final Set S3GLUE_DENY_LIST = + ImmutableSet.builder() + .add( + ALTER_VIEW, + CREATE_FUNCTION, + CREATE_VIEW, + DROP_FUNCTION, + DROP_VIEW, + INSERT, + LOAD, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + FILE, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + TRANSFORM, + MANAGE_RESOURCE, + DESCRIBE_FUNCTION, + REFRESH_RESOURCE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_FUNCTIONS, + SHOW_VIEWS, + MISC_FUNCTIONS, + UDF) + .build(); + + public S3GlueGrammarElementValidator() { + super(S3GLUE_DENY_LIST); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java new file mode 100644 index 0000000000..605a1d33be --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -0,0 +1,625 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import lombok.AllArgsConstructor; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTableColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTablePartitionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterClusterByContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterTableAlterColumnContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewQueryContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewSchemaBindingContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeTablesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClearCacheContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClusterBySpecContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableLikeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CtesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeQueryContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeRelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HiveReplaceColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteDirContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteHiveDirContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RecoverPartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshResourceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableColumnContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTablePartitionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RepairTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ReplaceTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetQuantifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableLocationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableSerDeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowCreateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowFunctionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowNamespacesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowPartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTableExtendedContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTablesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTblPropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableNameContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; + +@AllArgsConstructor +public class SQLQueryValidationVisitor extends SqlBaseParserBaseVisitor { + private final GrammarElementValidator grammarElementValidator; + + @Override + public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { + validateAllowed(GrammarElement.CREATE_FUNCTION); + return super.visitCreateFunction(ctx); + } + + @Override + public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetNamespaceProperties(ctx); + } + + @Override + public Void visitUnsetNamespaceProperties(UnsetNamespacePropertiesContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitUnsetNamespaceProperties(ctx); + } + + @Override + public Void visitAddTableColumns(AddTableColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAddTableColumns(ctx); + } + + @Override + public Void visitAddTablePartition(AddTablePartitionContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAddTablePartition(ctx); + } + + @Override + public Void visitRenameTableColumn(RenameTableColumnContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRenameTableColumn(ctx); + } + + @Override + public Void visitDropTableColumns(DropTableColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitDropTableColumns(ctx); + } + + @Override + public Void visitAlterTableAlterColumn(AlterTableAlterColumnContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAlterTableAlterColumn(ctx); + } + + @Override + public Void visitHiveReplaceColumns(HiveReplaceColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitHiveReplaceColumns(ctx); + } + + @Override + public Void visitSetTableSerDe(SetTableSerDeContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetTableSerDe(ctx); + } + + @Override + public Void visitRenameTablePartition(RenameTablePartitionContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRenameTablePartition(ctx); + } + + @Override + public Void visitDropTablePartitions(DropTablePartitionsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitDropTablePartitions(ctx); + } + + @Override + public Void visitSetTableLocation(SetTableLocationContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetTableLocation(ctx); + } + + @Override + public Void visitRecoverPartitions(RecoverPartitionsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRecoverPartitions(ctx); + } + + @Override + public Void visitAlterClusterBy(AlterClusterByContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAlterClusterBy(ctx); + } + + @Override + public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetNamespaceLocation(ctx); + } + + @Override + public Void visitAlterViewQuery(AlterViewQueryContext ctx) { + validateAllowed(GrammarElement.ALTER_VIEW); + return super.visitAlterViewQuery(ctx); + } + + @Override + public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { + validateAllowed(GrammarElement.ALTER_VIEW); + return super.visitAlterViewSchemaBinding(ctx); + } + + @Override + public Void visitRenameTable(RenameTableContext ctx) { + TerminalNode view = ctx.VIEW(); + TerminalNode table = ctx.TABLE(); + if (ctx.VIEW() != null) { + validateAllowed(GrammarElement.ALTER_VIEW); + } else if (ctx.TABLE() != null) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + } + + return super.visitRenameTable(ctx); + } + + @Override + public Void visitCreateNamespace(CreateNamespaceContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateNamespace(ctx); + } + + @Override + public Void visitCreateTable(CreateTableContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateTable(ctx); + } + + @Override + public Void visitCreateTableLike(CreateTableLikeContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateTableLike(ctx); + } + + @Override + public Void visitReplaceTable(ReplaceTableContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitReplaceTable(ctx); + } + + @Override + public Void visitDropNamespace(DropNamespaceContext ctx) { + validateAllowed(GrammarElement.DROP_NAMESPACE); + return super.visitDropNamespace(ctx); + } + + @Override + public Void visitDropTable(DropTableContext ctx) { + validateAllowed(GrammarElement.DROP_NAMESPACE); + return super.visitDropTable(ctx); + } + + @Override + public Void visitCreateView(CreateViewContext ctx) { + validateAllowed(GrammarElement.CREATE_VIEW); + return super.visitCreateView(ctx); + } + + @Override + public Void visitDropView(DropViewContext ctx) { + validateAllowed(GrammarElement.DROP_VIEW); + return super.visitDropView(ctx); + } + + @Override + public Void visitDropFunction(DropFunctionContext ctx) { + validateAllowed(GrammarElement.DROP_FUNCTION); + return super.visitDropFunction(ctx); + } + + @Override + public Void visitRepairTable(RepairTableContext ctx) { + validateAllowed(GrammarElement.REPAIR_TABLE); + return super.visitRepairTable(ctx); + } + + @Override + public Void visitTruncateTable(TruncateTableContext ctx) { + validateAllowed(GrammarElement.TRUNCATE_TABLE); + return super.visitTruncateTable(ctx); + } + + @Override + public Void visitInsertOverwriteTable(InsertOverwriteTableContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteTable(ctx); + } + + @Override + public Void visitInsertIntoReplaceWhere(InsertIntoReplaceWhereContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertIntoReplaceWhere(ctx); + } + + @Override + public Void visitInsertIntoTable(InsertIntoTableContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertIntoTable(ctx); + } + + @Override + public Void visitInsertOverwriteDir(InsertOverwriteDirContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteDir(ctx); + } + + @Override + public Void visitInsertOverwriteHiveDir(InsertOverwriteHiveDirContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteHiveDir(ctx); + } + + @Override + public Void visitLoadData(LoadDataContext ctx) { + validateAllowed(GrammarElement.LOAD); + return super.visitLoadData(ctx); + } + + @Override + public Void visitExplain(ExplainContext ctx) { + validateAllowed(GrammarElement.EXPLAIN); + return super.visitExplain(ctx); + } + + @Override + public Void visitTableName(TableNameContext ctx) { + String reference = ctx.identifierReference().getText(); + if (isFileReference(reference)) { + validateAllowed(GrammarElement.FILE); + } + return super.visitTableName(ctx); + } + + private static final String FILE_REFERENCE_PATTERN = "^[a-zA-Z]+\\.`[^`]+`$"; + + private boolean isFileReference(String reference) { + return reference.matches(FILE_REFERENCE_PATTERN); + } + + @Override + public Void visitCtes(CtesContext ctx) { + validateAllowed(GrammarElement.WITH); + return super.visitCtes(ctx); + } + + @Override + public Void visitClusterBySpec(ClusterBySpecContext ctx) { + validateAllowed(GrammarElement.CLUSTER_BY); + return super.visitClusterBySpec(ctx); + } + + @Override + public Void visitQueryOrganization(QueryOrganizationContext ctx) { + if (ctx.CLUSTER() != null) { + validateAllowed(GrammarElement.CLUSTER_BY); + } else if (ctx.DISTRIBUTE() != null) { + validateAllowed(GrammarElement.DISTRIBUTE_BY); + } + return super.visitQueryOrganization(ctx); + } + + @Override + public Void visitHint(HintContext ctx) { + validateAllowed(GrammarElement.HINTS); + return super.visitHint(ctx); + } + + @Override + public Void visitInlineTable(InlineTableContext ctx) { + validateAllowed(GrammarElement.INLINE_TABLE); + return super.visitInlineTable(ctx); + } + + @Override + public Void visitJoinType(JoinTypeContext ctx) { + if (ctx.CROSS() != null) { + validateAllowed(GrammarElement.CROSS_JOIN); + } else if (ctx.LEFT() != null && ctx.SEMI() != null) { + validateAllowed(GrammarElement.LEFT_SEMI_JOIN); + } else if (ctx.ANTI() != null) { + validateAllowed(GrammarElement.LEFT_ANTI_JOIN); + } else if (ctx.LEFT() != null) { + validateAllowed(GrammarElement.LEFT_OUTER_JOIN); + } else if (ctx.RIGHT() != null) { + validateAllowed(GrammarElement.RIGHT_OUTER_JOIN); + } else if (ctx.FULL() != null) { + validateAllowed(GrammarElement.FULL_OUTER_JOIN); + } else { + validateAllowed(GrammarElement.INNER_JOIN); + } + return super.visitJoinType(ctx); + } + + @Override + public Void visitSample(SampleContext ctx) { + validateAllowed(GrammarElement.TABLESAMPLE); + return super.visitSample(ctx); + } + + @Override + public Void visitTableValuedFunction(TableValuedFunctionContext ctx) { + validateAllowed(GrammarElement.TABLE_VALUED_FUNCTION); + return super.visitTableValuedFunction(ctx); + } + + @Override + public Void visitLateralView(LateralViewContext ctx) { + validateAllowed(GrammarElement.LATERAL_VIEW); + return super.visitLateralView(ctx); + } + + @Override + public Void visitRelation(RelationContext ctx) { + if (ctx.LATERAL() != null) { + validateAllowed(GrammarElement.LATERAL_SUBQUERY); + } + return super.visitRelation(ctx); + } + + @Override + public Void visitJoinRelation(JoinRelationContext ctx) { + if (ctx.LATERAL() != null) { + validateAllowed(GrammarElement.LATERAL_SUBQUERY); + } + return super.visitJoinRelation(ctx); + } + + @Override + public Void visitTransformClause(TransformClauseContext ctx) { + if (ctx.TRANSFORM() != null) { + validateAllowed(GrammarElement.TRANSFORM); + } + return super.visitTransformClause(ctx); + } + + @Override + public Void visitManageResource(ManageResourceContext ctx) { + validateAllowed(GrammarElement.MANAGE_RESOURCE); + return super.visitManageResource(ctx); + } + + @Override + public Void visitAnalyze(AnalyzeContext ctx) { + validateAllowed(GrammarElement.ANALYZE_TABLE); + return super.visitAnalyze(ctx); + } + + @Override + public Void visitAnalyzeTables(AnalyzeTablesContext ctx) { + validateAllowed(GrammarElement.ANALYZE_TABLE); + return super.visitAnalyzeTables(ctx); + } + + @Override + public Void visitCacheTable(CacheTableContext ctx) { + validateAllowed(GrammarElement.CACHE_TABLE); + return super.visitCacheTable(ctx); + } + + @Override + public Void visitClearCache(ClearCacheContext ctx) { + validateAllowed(GrammarElement.CLEAR_CACHE); + return super.visitClearCache(ctx); + } + + @Override + public Void visitDescribeNamespace(DescribeNamespaceContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_NAMESPACE); + return super.visitDescribeNamespace(ctx); + } + + @Override + public Void visitDescribeFunction(DescribeFunctionContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_FUNCTION); + return super.visitDescribeFunction(ctx); + } + + @Override + public Void visitDescribeRelation(DescribeRelationContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_TABLE); + return super.visitDescribeRelation(ctx); + } + + @Override + public Void visitDescribeQuery(DescribeQueryContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_QUERY); + return super.visitDescribeQuery(ctx); + } + + @Override + public Void visitRefreshResource(RefreshResourceContext ctx) { + validateAllowed(GrammarElement.REFRESH_RESOURCE); + return super.visitRefreshResource(ctx); + } + + @Override + public Void visitRefreshTable(RefreshTableContext ctx) { + validateAllowed(GrammarElement.REFRESH_TABLE); + return super.visitRefreshTable(ctx); + } + + @Override + public Void visitRefreshFunction(RefreshFunctionContext ctx) { + validateAllowed(GrammarElement.REFRESH_FUNCTION); + return super.visitRefreshFunction(ctx); + } + + @Override + public Void visitResetConfiguration(ResetConfigurationContext ctx) { + validateAllowed(GrammarElement.RESET); + return super.visitResetConfiguration(ctx); + } + + @Override + public Void visitResetQuotedConfiguration(ResetQuotedConfigurationContext ctx) { + validateAllowed(GrammarElement.RESET); + return super.visitResetQuotedConfiguration(ctx); + } + + @Override + public Void visitSetConfiguration(SetConfigurationContext ctx) { + validateAllowed(GrammarElement.SET); + return super.visitSetConfiguration(ctx); + } + + @Override + public Void visitSetQuantifier(SetQuantifierContext ctx) { + validateAllowed(GrammarElement.SET); + return super.visitSetQuantifier(ctx); + } + + @Override + public Void visitShowColumns(ShowColumnsContext ctx) { + validateAllowed(GrammarElement.SHOW_COLUMNS); + return super.visitShowColumns(ctx); + } + + @Override + public Void visitShowCreateTable(ShowCreateTableContext ctx) { + validateAllowed(GrammarElement.SHOW_CREATE_TABLE); + return super.visitShowCreateTable(ctx); + } + + @Override + public Void visitShowNamespaces(ShowNamespacesContext ctx) { + validateAllowed(GrammarElement.SHOW_NAMESPACES); + return super.visitShowNamespaces(ctx); + } + + @Override + public Void visitShowFunctions(ShowFunctionsContext ctx) { + validateAllowed(GrammarElement.SHOW_FUNCTIONS); + return super.visitShowFunctions(ctx); + } + + @Override + public Void visitShowPartitions(ShowPartitionsContext ctx) { + validateAllowed(GrammarElement.SHOW_PARTITIONS); + return super.visitShowPartitions(ctx); + } + + @Override + public Void visitShowTableExtended(ShowTableExtendedContext ctx) { + validateAllowed(GrammarElement.SHOW_TABLE_EXTENDED); + return super.visitShowTableExtended(ctx); + } + + @Override + public Void visitShowTables(ShowTablesContext ctx) { + validateAllowed(GrammarElement.SHOW_TABLES); + return super.visitShowTables(ctx); + } + + @Override + public Void visitShowTblProperties(ShowTblPropertiesContext ctx) { + validateAllowed(GrammarElement.SHOW_TBLPROPERTIES); + return super.visitShowTblProperties(ctx); + } + + @Override + public Void visitShowViews(ShowViewsContext ctx) { + validateAllowed(GrammarElement.SHOW_VIEWS); + return super.visitShowViews(ctx); + } + + @Override + public Void visitUncacheTable(UncacheTableContext ctx) { + validateAllowed(GrammarElement.UNCACHE_TABLE); + return super.visitUncacheTable(ctx); + } + + @Override + public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) { + validateFunctionAllowed(ctx.function.getText()); + return super.visitFunctionIdentifier(ctx); + } + + @Override + public Void visitFunctionName(FunctionNameContext ctx) { + validateFunctionAllowed(ctx.qualifiedName().getText()); + return super.visitFunctionName(ctx); + } + + private void validateFunctionAllowed(String function) { + FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); + switch (type) { + case MAP: + validateAllowed(GrammarElement.MAP_FUNCTIONS); + break; + case CSV: + validateAllowed(GrammarElement.CSV_FUNCTIONS); + break; + case MISC: + validateAllowed(GrammarElement.MISC_FUNCTIONS); + break; + case UDF: + validateAllowed(GrammarElement.UDF); + break; + } + } + + private void validateAllowed(GrammarElement element) { + if (!grammarElementValidator.isValid(element)) { + throw new IllegalArgumentException(element + " is not allowed."); + } + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index 2fd91bf87a..6d41a13db8 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -6,625 +6,17 @@ package org.opensearch.sql.spark.validator; import lombok.AllArgsConstructor; -import org.antlr.v4.runtime.tree.TerminalNode; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTableColumnsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTablePartitionContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterClusterByContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterTableAlterColumnContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewQueryContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewSchemaBindingContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeTablesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CacheTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClearCacheContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClusterBySpecContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateNamespaceContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableLikeContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateViewContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CtesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeFunctionContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeNamespaceContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeQueryContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeRelationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropFunctionContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropNamespaceContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableColumnsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HiveReplaceColumnsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteDirContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteHiveDirContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RecoverPartitionsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshFunctionContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshResourceContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RelationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableColumnContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTablePartitionContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RepairTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ReplaceTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetQuantifierContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableLocationContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableSerDeContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowCreateTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowFunctionsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowNamespacesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowPartitionsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTableExtendedContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTablesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTblPropertiesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableNameContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.utils.SQLQueryUtils; @AllArgsConstructor -public class SQLQueryValidator extends SqlBaseParserBaseVisitor { - private final GrammarElementValidator grammarElementValidator; - - public void validate(SqlBaseParser.SingleStatementContext statement) { - this.visit(statement); - } - - @Override - public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { - validateAllowed(GrammarElement.CREATE_FUNCTION); - return super.visitCreateFunction(ctx); - } - - @Override - public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitSetNamespaceProperties(ctx); - } - - @Override - public Void visitUnsetNamespaceProperties(UnsetNamespacePropertiesContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitUnsetNamespaceProperties(ctx); - } - - @Override - public Void visitAddTableColumns(AddTableColumnsContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitAddTableColumns(ctx); - } - - @Override - public Void visitAddTablePartition(AddTablePartitionContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitAddTablePartition(ctx); - } - - @Override - public Void visitRenameTableColumn(RenameTableColumnContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitRenameTableColumn(ctx); - } - - @Override - public Void visitDropTableColumns(DropTableColumnsContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitDropTableColumns(ctx); - } - - @Override - public Void visitAlterTableAlterColumn(AlterTableAlterColumnContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitAlterTableAlterColumn(ctx); - } - - @Override - public Void visitHiveReplaceColumns(HiveReplaceColumnsContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitHiveReplaceColumns(ctx); - } - - @Override - public Void visitSetTableSerDe(SetTableSerDeContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitSetTableSerDe(ctx); - } - - @Override - public Void visitRenameTablePartition(RenameTablePartitionContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitRenameTablePartition(ctx); - } - - @Override - public Void visitDropTablePartitions(DropTablePartitionsContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitDropTablePartitions(ctx); - } - - @Override - public Void visitSetTableLocation(SetTableLocationContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitSetTableLocation(ctx); - } - - @Override - public Void visitRecoverPartitions(RecoverPartitionsContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitRecoverPartitions(ctx); - } - - @Override - public Void visitAlterClusterBy(AlterClusterByContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitAlterClusterBy(ctx); - } - - @Override - public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - return super.visitSetNamespaceLocation(ctx); - } - - @Override - public Void visitAlterViewQuery(AlterViewQueryContext ctx) { - validateAllowed(GrammarElement.ALTER_VIEW); - return super.visitAlterViewQuery(ctx); - } - - @Override - public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { - validateAllowed(GrammarElement.ALTER_VIEW); - return super.visitAlterViewSchemaBinding(ctx); - } - - @Override - public Void visitRenameTable(RenameTableContext ctx) { - TerminalNode view = ctx.VIEW(); - TerminalNode table = ctx.TABLE(); - if (ctx.VIEW() != null) { - validateAllowed(GrammarElement.ALTER_VIEW); - } else if (ctx.TABLE() != null) { - validateAllowed(GrammarElement.ALTER_NAMESPACE); - } - - return super.visitRenameTable(ctx); - } - - @Override - public Void visitCreateNamespace(CreateNamespaceContext ctx) { - validateAllowed(GrammarElement.CREATE_NAMESPACE); - return super.visitCreateNamespace(ctx); - } - - @Override - public Void visitCreateTable(CreateTableContext ctx) { - validateAllowed(GrammarElement.CREATE_NAMESPACE); - return super.visitCreateTable(ctx); - } - - @Override - public Void visitCreateTableLike(CreateTableLikeContext ctx) { - validateAllowed(GrammarElement.CREATE_NAMESPACE); - return super.visitCreateTableLike(ctx); - } - - @Override - public Void visitReplaceTable(ReplaceTableContext ctx) { - validateAllowed(GrammarElement.CREATE_NAMESPACE); - return super.visitReplaceTable(ctx); - } - - @Override - public Void visitDropNamespace(DropNamespaceContext ctx) { - validateAllowed(GrammarElement.DROP_NAMESPACE); - return super.visitDropNamespace(ctx); - } - - @Override - public Void visitDropTable(DropTableContext ctx) { - validateAllowed(GrammarElement.DROP_NAMESPACE); - return super.visitDropTable(ctx); - } - - @Override - public Void visitCreateView(CreateViewContext ctx) { - validateAllowed(GrammarElement.CREATE_VIEW); - return super.visitCreateView(ctx); - } - - @Override - public Void visitDropView(DropViewContext ctx) { - validateAllowed(GrammarElement.DROP_VIEW); - return super.visitDropView(ctx); - } - - @Override - public Void visitDropFunction(DropFunctionContext ctx) { - validateAllowed(GrammarElement.DROP_FUNCTION); - return super.visitDropFunction(ctx); - } - - @Override - public Void visitRepairTable(RepairTableContext ctx) { - validateAllowed(GrammarElement.REPAIR_TABLE); - return super.visitRepairTable(ctx); - } - - @Override - public Void visitTruncateTable(TruncateTableContext ctx) { - validateAllowed(GrammarElement.TRUNCATE_TABLE); - return super.visitTruncateTable(ctx); - } - - @Override - public Void visitInsertOverwriteTable(InsertOverwriteTableContext ctx) { - validateAllowed(GrammarElement.INSERT); - return super.visitInsertOverwriteTable(ctx); - } - - @Override - public Void visitInsertIntoReplaceWhere(InsertIntoReplaceWhereContext ctx) { - validateAllowed(GrammarElement.INSERT); - return super.visitInsertIntoReplaceWhere(ctx); - } - - @Override - public Void visitInsertIntoTable(InsertIntoTableContext ctx) { - validateAllowed(GrammarElement.INSERT); - return super.visitInsertIntoTable(ctx); - } - - @Override - public Void visitInsertOverwriteDir(InsertOverwriteDirContext ctx) { - validateAllowed(GrammarElement.INSERT); - return super.visitInsertOverwriteDir(ctx); - } - - @Override - public Void visitInsertOverwriteHiveDir(InsertOverwriteHiveDirContext ctx) { - validateAllowed(GrammarElement.INSERT); - return super.visitInsertOverwriteHiveDir(ctx); - } - - @Override - public Void visitLoadData(LoadDataContext ctx) { - validateAllowed(GrammarElement.LOAD); - return super.visitLoadData(ctx); - } - - @Override - public Void visitExplain(ExplainContext ctx) { - validateAllowed(GrammarElement.EXPLAIN); - return super.visitExplain(ctx); - } - - @Override - public Void visitTableName(TableNameContext ctx) { - String reference = ctx.identifierReference().getText(); - System.out.println(reference); - if (isFileReference(reference)) { - validateAllowed(GrammarElement.FILE); - } - return super.visitTableName(ctx); - } - - private static final String FILE_REFERENCE_PATTERN = "^[a-zA-Z]+\\.`[^`]+`$"; - - private boolean isFileReference(String reference) { - return reference.matches(FILE_REFERENCE_PATTERN); - } - - @Override - public Void visitCtes(CtesContext ctx) { - validateAllowed(GrammarElement.WITH); - return super.visitCtes(ctx); - } - - @Override - public Void visitClusterBySpec(ClusterBySpecContext ctx) { - validateAllowed(GrammarElement.CLUSTER_BY); - return super.visitClusterBySpec(ctx); - } - - @Override - public Void visitQueryOrganization(QueryOrganizationContext ctx) { - if (ctx.CLUSTER() != null) { - validateAllowed(GrammarElement.CLUSTER_BY); - } else if (ctx.DISTRIBUTE() != null) { - validateAllowed(GrammarElement.DISTRIBUTE_BY); - } - return super.visitQueryOrganization(ctx); - } - - @Override - public Void visitHint(HintContext ctx) { - validateAllowed(GrammarElement.HINTS); - return super.visitHint(ctx); - } - - @Override - public Void visitInlineTable(InlineTableContext ctx) { - validateAllowed(GrammarElement.INLINE_TABLE); - return super.visitInlineTable(ctx); - } - - @Override - public Void visitJoinType(JoinTypeContext ctx) { - if (ctx.CROSS() != null) { - validateAllowed(GrammarElement.CROSS_JOIN); - } else if (ctx.LEFT() != null && ctx.SEMI() != null) { - validateAllowed(GrammarElement.LEFT_SEMI_JOIN); - } else if (ctx.ANTI() != null) { - validateAllowed(GrammarElement.LEFT_ANTI_JOIN); - } else if (ctx.LEFT() != null) { - validateAllowed(GrammarElement.LEFT_OUTER_JOIN); - } else if (ctx.RIGHT() != null) { - validateAllowed(GrammarElement.RIGHT_OUTER_JOIN); - } else if (ctx.FULL() != null) { - validateAllowed(GrammarElement.FULL_OUTER_JOIN); - } else { - validateAllowed(GrammarElement.INNER_JOIN); - } - return super.visitJoinType(ctx); - } - - @Override - public Void visitSample(SampleContext ctx) { - validateAllowed(GrammarElement.TABLESAMPLE); - return super.visitSample(ctx); - } - - @Override - public Void visitTableValuedFunction(TableValuedFunctionContext ctx) { - validateAllowed(GrammarElement.TABLE_VALUED_FUNCTION); - return super.visitTableValuedFunction(ctx); - } - - @Override - public Void visitLateralView(LateralViewContext ctx) { - validateAllowed(GrammarElement.LATERAL_VIEW); - return super.visitLateralView(ctx); - } - - @Override - public Void visitRelation(RelationContext ctx) { - if (ctx.LATERAL() != null) { - validateAllowed(GrammarElement.LATERAL_SUBQUERY); - } - return super.visitRelation(ctx); - } - - @Override - public Void visitJoinRelation(JoinRelationContext ctx) { - if (ctx.LATERAL() != null) { - validateAllowed(GrammarElement.LATERAL_SUBQUERY); - } - return super.visitJoinRelation(ctx); - } - - @Override - public Void visitTransformClause(TransformClauseContext ctx) { - if (ctx.TRANSFORM() != null) { - validateAllowed(GrammarElement.TRANSFORM); - } - return super.visitTransformClause(ctx); - } - - @Override - public Void visitManageResource(ManageResourceContext ctx) { - validateAllowed(GrammarElement.MANAGE_RESOURCE); - return super.visitManageResource(ctx); - } - - @Override - public Void visitAnalyze(AnalyzeContext ctx) { - validateAllowed(GrammarElement.ANALYZE_TABLE); - return super.visitAnalyze(ctx); - } - - @Override - public Void visitAnalyzeTables(AnalyzeTablesContext ctx) { - validateAllowed(GrammarElement.ANALYZE_TABLE); - return super.visitAnalyzeTables(ctx); - } - - @Override - public Void visitCacheTable(CacheTableContext ctx) { - validateAllowed(GrammarElement.CACHE_TABLE); - return super.visitCacheTable(ctx); - } - - @Override - public Void visitClearCache(ClearCacheContext ctx) { - validateAllowed(GrammarElement.CLEAR_CACHE); - return super.visitClearCache(ctx); - } - - @Override - public Void visitDescribeNamespace(DescribeNamespaceContext ctx) { - validateAllowed(GrammarElement.DESCRIBE_NAMESPACE); - return super.visitDescribeNamespace(ctx); - } - - @Override - public Void visitDescribeFunction(DescribeFunctionContext ctx) { - validateAllowed(GrammarElement.DESCRIBE_FUNCTION); - return super.visitDescribeFunction(ctx); - } - - @Override - public Void visitDescribeRelation(DescribeRelationContext ctx) { - validateAllowed(GrammarElement.DESCRIBE_TABLE); - return super.visitDescribeRelation(ctx); - } - - @Override - public Void visitDescribeQuery(DescribeQueryContext ctx) { - validateAllowed(GrammarElement.DESCRIBE_QUERY); - return super.visitDescribeQuery(ctx); - } - - @Override - public Void visitRefreshResource(RefreshResourceContext ctx) { - validateAllowed(GrammarElement.REFRESH_RESOURCE); - return super.visitRefreshResource(ctx); - } - - @Override - public Void visitRefreshTable(RefreshTableContext ctx) { - validateAllowed(GrammarElement.REFRESH_TABLE); - return super.visitRefreshTable(ctx); - } - - @Override - public Void visitRefreshFunction(RefreshFunctionContext ctx) { - validateAllowed(GrammarElement.REFRESH_FUNCTION); - return super.visitRefreshFunction(ctx); - } - - @Override - public Void visitResetConfiguration(ResetConfigurationContext ctx) { - validateAllowed(GrammarElement.RESET); - return super.visitResetConfiguration(ctx); - } - - @Override - public Void visitResetQuotedConfiguration(ResetQuotedConfigurationContext ctx) { - validateAllowed(GrammarElement.RESET); - return super.visitResetQuotedConfiguration(ctx); - } - - @Override - public Void visitSetConfiguration(SetConfigurationContext ctx) { - validateAllowed(GrammarElement.SET); - return super.visitSetConfiguration(ctx); - } - - @Override - public Void visitSetQuantifier(SetQuantifierContext ctx) { - validateAllowed(GrammarElement.SET); - return super.visitSetQuantifier(ctx); - } - - @Override - public Void visitShowColumns(ShowColumnsContext ctx) { - validateAllowed(GrammarElement.SHOW_COLUMNS); - return super.visitShowColumns(ctx); - } - - @Override - public Void visitShowCreateTable(ShowCreateTableContext ctx) { - validateAllowed(GrammarElement.SHOW_CREATE_TABLE); - return super.visitShowCreateTable(ctx); - } - - @Override - public Void visitShowNamespaces(ShowNamespacesContext ctx) { - validateAllowed(GrammarElement.SHOW_NAMESPACES); - return super.visitShowNamespaces(ctx); - } - - @Override - public Void visitShowFunctions(ShowFunctionsContext ctx) { - validateAllowed(GrammarElement.SHOW_FUNCTIONS); - return super.visitShowFunctions(ctx); - } - - @Override - public Void visitShowPartitions(ShowPartitionsContext ctx) { - validateAllowed(GrammarElement.SHOW_PARTITIONS); - return super.visitShowPartitions(ctx); - } - - @Override - public Void visitShowTableExtended(ShowTableExtendedContext ctx) { - validateAllowed(GrammarElement.SHOW_TABLE_EXTENDED); - return super.visitShowTableExtended(ctx); - } - - @Override - public Void visitShowTables(ShowTablesContext ctx) { - validateAllowed(GrammarElement.SHOW_TABLES); - return super.visitShowTables(ctx); - } - - @Override - public Void visitShowTblProperties(ShowTblPropertiesContext ctx) { - validateAllowed(GrammarElement.SHOW_TBLPROPERTIES); - return super.visitShowTblProperties(ctx); - } - - @Override - public Void visitShowViews(ShowViewsContext ctx) { - validateAllowed(GrammarElement.SHOW_VIEWS); - return super.visitShowViews(ctx); - } - - @Override - public Void visitUncacheTable(UncacheTableContext ctx) { - validateAllowed(GrammarElement.UNCACHE_TABLE); - return super.visitUncacheTable(ctx); - } - - @Override - public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) { - validateFunctionAllowed(ctx.function.getText()); - return super.visitFunctionIdentifier(ctx); - } - - @Override - public Void visitFunctionName(FunctionNameContext ctx) { - validateFunctionAllowed(ctx.qualifiedName().getText()); - return super.visitFunctionName(ctx); - } - - private void validateFunctionAllowed(String function) { - FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); - switch (type) { - case MAP: - validateAllowed(GrammarElement.MAP_FUNCTIONS); - break; - case CSV: - validateAllowed(GrammarElement.CSV_FUNCTIONS); - break; - case MISC: - validateAllowed(GrammarElement.MISC_FUNCTIONS); - break; - case UDF: - validateAllowed(GrammarElement.UDF); - break; - } - } - - private void validateAllowed(GrammarElement element) { - if (!grammarElementValidator.isValid(element)) { - throw new IllegalArgumentException(element + " is not allowed."); - } +public class SQLQueryValidator { + private final GrammarElementValidatorFactory grammarElementValidatorFactory; + + public void validate(String sqlQuery, DataSourceType datasourceType) { + GrammarElementValidator grammarElementValidator = + grammarElementValidatorFactory.getValidatorForDatasource(datasourceType); + SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(grammarElementValidator); + visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement()); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java new file mode 100644 index 0000000000..7dd2b0ee89 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.opensearch.sql.spark.validator.GrammarElement.ALTER_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.ALTER_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.ANALYZE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.CACHE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.CLEAR_CACHE; +import static org.opensearch.sql.spark.validator.GrammarElement.CLUSTER_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.CROSS_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.CSV_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_QUERY; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.DISTRIBUTE_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.FILE; +import static org.opensearch.sql.spark.validator.GrammarElement.FULL_OUTER_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.HINTS; +import static org.opensearch.sql.spark.validator.GrammarElement.INLINE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.INSERT; +import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_ANTI_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_SEMI_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.LOAD; +import static org.opensearch.sql.spark.validator.GrammarElement.MANAGE_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.MISC_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.REPAIR_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.RESET; +import static org.opensearch.sql.spark.validator.GrammarElement.RIGHT_OUTER_JOIN; +import static org.opensearch.sql.spark.validator.GrammarElement.SET; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_COLUMNS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_CREATE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_NAMESPACES; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_PARTITIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_TABLES; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_TABLE_EXTENDED; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_TBLPROPERTIES; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_VIEWS; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLESAMPLE; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLE_VALUED_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.TRANSFORM; +import static org.opensearch.sql.spark.validator.GrammarElement.TRUNCATE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.UDF; +import static org.opensearch.sql.spark.validator.GrammarElement.UNCACHE_TABLE; + +import com.google.common.collect.ImmutableSet; +import java.util.Set; + +public class SecurityLakeGrammarElementValidator extends DenyListGrammarElementValidator { + private static final Set SECURITY_LAKE_DENY_LIST = + ImmutableSet.builder() + .add( + ALTER_NAMESPACE, + ALTER_VIEW, + CREATE_NAMESPACE, + CREATE_FUNCTION, + CREATE_VIEW, + DROP_FUNCTION, + DROP_NAMESPACE, + DROP_VIEW, + REPAIR_TABLE, + TRUNCATE_TABLE, + INSERT, + LOAD, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + FILE, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + TRANSFORM, + MANAGE_RESOURCE, + ANALYZE_TABLE, + CACHE_TABLE, + CLEAR_CACHE, + DESCRIBE_NAMESPACE, + DESCRIBE_FUNCTION, + DESCRIBE_QUERY, + DESCRIBE_TABLE, + REFRESH_RESOURCE, + REFRESH_TABLE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_COLUMNS, + SHOW_CREATE_TABLE, + SHOW_NAMESPACES, + SHOW_FUNCTIONS, + SHOW_PARTITIONS, + SHOW_TABLE_EXTENDED, + SHOW_TABLES, + SHOW_TBLPROPERTIES, + SHOW_VIEWS, + UNCACHE_TABLE, + CSV_FUNCTIONS, + MISC_FUNCTIONS, + UDF) + .build(); + + public SecurityLakeGrammarElementValidator() { + super(SECURITY_LAKE_DENY_LIST); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index ddadeb65e2..f98e7b32e3 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -85,6 +85,8 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; +import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.SQLQueryValidator; /** * This tests async-query-core library end-to-end using mocked implementation of extension points. @@ -175,9 +177,15 @@ public void setUp() { emrServerlessClientFactory, metricsService, new SparkSubmitParametersBuilderProvider(collection)); + SQLQueryValidator sqlQueryValidator = + new SQLQueryValidator(new GrammarElementValidatorFactory()); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 75c0e00337..3a02fd9787 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -88,6 +88,8 @@ import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; +import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.SQLQueryValidator; @ExtendWith(MockitoExtension.class) public class SparkQueryDispatcherTest { @@ -111,6 +113,10 @@ public class SparkQueryDispatcherTest { @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Mock private MetricsService metricsService; @Mock private AsyncQueryScheduler asyncQueryScheduler; + + private final SQLQueryValidator sqlQueryValidator = + new SQLQueryValidator(new GrammarElementValidatorFactory()); + private DataSourceSparkParameterComposer dataSourceSparkParameterComposer = (datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> { sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, "basic"); @@ -159,7 +165,11 @@ void setUp() { sparkSubmitParametersBuilderProvider); sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); } @Test @@ -571,7 +581,11 @@ void testDispatchAlterToManualRefreshIndexQuery() { QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class); sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); String query = "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = false)"; @@ -597,7 +611,11 @@ void testDispatchDropIndexQuery() { QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class); sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 747485e718..16f8aa0955 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -18,13 +18,13 @@ class SQLQueryValidatorTest { GrammarElementValidatorFactory factory = new GrammarElementValidatorFactory(); + SQLQueryValidator sqlQueryValidator = new SQLQueryValidator(factory); @AllArgsConstructor private enum TestQuery { // DDL Statements ALTER_DATABASE( - "ALTER DATABASE inventory SET DBPROPERTIES ('Edited-by' = 'John', 'Edit-date' =" - + " '01/01/2001');"), + "ALTER DATABASE inventory SET DBPROPERTIES ('Edit-date' = '01/01/2001');"), ALTER_TABLE( "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');"), ALTER_VIEW("ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;"), @@ -74,7 +74,7 @@ private enum TestQuery { LEFT_ANTI_JOIN("SELECT t1.name FROM table1 t1 LEFT ANTI JOIN table2 t2 ON t1.id = t2.id;"), LIKE_PREDICATE("SELECT * FROM my_table WHERE name LIKE 'A%';"), LIMIT_CLAUSE("SELECT * FROM my_table LIMIT 10;"), - OFFSET_CLAUSE("SELECT * FROM my_table OFFSET 5 ROWS;"), + OFFSET_CLAUSE("SELECT * FROM my_table OFFSET 5;"), ORDER_BY_CLAUSE("SELECT * FROM my_table ORDER BY age DESC;"), SET_OPERATORS("SELECT * FROM table1 UNION SELECT * FROM table2;"), SORT_BY_CLAUSE("SELECT * FROM my_table SORT BY age DESC;"), @@ -164,9 +164,7 @@ public String toString() { @Test void s3glueQueries() { - VerifyValidator v = - new VerifyValidator( - new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE))); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE); // DDL Statements v.ok(TestQuery.ALTER_DATABASE); v.ok(TestQuery.ALTER_TABLE); @@ -236,8 +234,7 @@ void s3glueQueries() { v.ng(TestQuery.LIST_FILE); v.ng(TestQuery.LIST_JAR); v.ng(TestQuery.REFRESH); - // v.ok(TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to - // refreshResource) + v.ok(TestQuery.REFRESH_TABLE); v.ng(TestQuery.REFRESH_FUNCTION); v.ng(TestQuery.RESET); v.ng(TestQuery.SET); @@ -281,9 +278,7 @@ void s3glueQueries() { @Test void securityLakeQueries() { - VerifyValidator v = - new VerifyValidator( - new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.SECURITY_LAKE))); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SECURITY_LAKE); // DDL Statements v.ng(TestQuery.ALTER_DATABASE); v.ng(TestQuery.ALTER_TABLE); @@ -353,8 +348,7 @@ void securityLakeQueries() { v.ng(TestQuery.LIST_FILE); v.ng(TestQuery.LIST_JAR); v.ng(TestQuery.REFRESH); - // v.ng(TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to - // refreshResource) + v.ng(TestQuery.REFRESH_TABLE); v.ng(TestQuery.REFRESH_FUNCTION); v.ng(TestQuery.RESET); v.ng(TestQuery.SET); @@ -399,20 +393,21 @@ void securityLakeQueries() { @AllArgsConstructor private static class VerifyValidator { private final SQLQueryValidator validator; + private final DataSourceType dataSourceType; public void ok(TestQuery query) { - runValidate(validator, query.toString()); + runValidate(query.toString()); } public void ng(TestQuery query) { assertThrows( IllegalArgumentException.class, - () -> runValidate(validator, query.toString()), + () -> runValidate(query.toString()), "The query should throw: query=`" + query.toString() + "`"); } - void runValidate(SQLQueryValidator validator, String query) { - validator.validate(getParser(query)); + void runValidate(String query) { + validator.validate(query, dataSourceType); } SingleStatementContext getParser(String query) { diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index c6f6ffcd81..74c5d7df14 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -64,6 +64,7 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.validator.SQLQueryValidator; @RequiredArgsConstructor public class AsyncExecutorServiceModule extends AbstractModule { @@ -101,9 +102,10 @@ public SparkQueryDispatcher sparkQueryDispatcher( DataSourceService dataSourceService, SessionManager sessionManager, QueryHandlerFactory queryHandlerFactory, - QueryIdProvider queryIdProvider) { + QueryIdProvider queryIdProvider, + SQLQueryValidator sqlQueryValidator) { return new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider, sqlQueryValidator); } @Provides diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 9b897d36b4..3e3d5217e0 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -102,6 +102,8 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.SQLQueryValidator; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; @@ -308,6 +310,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( emrServerlessClientFactory, new OpenSearchMetricsService(), sparkSubmitParametersBuilderProvider); + SQLQueryValidator sqlQueryValidator = + new SQLQueryValidator(new GrammarElementValidatorFactory()); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService, @@ -318,7 +322,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( sessionConfigSupplier, sessionIdProvider), queryHandlerFactory, - new DatasourceEmbeddedQueryIdProvider()); + new DatasourceEmbeddedQueryIdProvider(), + sqlQueryValidator); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, From 7dde72a9e2cab789a9b3f26f3ddf960b466ff43b Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Wed, 18 Sep 2024 14:24:37 -0700 Subject: [PATCH 07/14] Fix style Signed-off-by: Tomoyuki Morita --- .../sql/spark/validator/GrammarElementValidatorFactory.java | 4 ++-- .../opensearch/sql/spark/validator/SQLQueryValidatorTest.java | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index 68c230fc71..c954e4f570 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -12,8 +12,8 @@ public class GrammarElementValidatorFactory { - private static GrammarElementValidator defaultValidator = new DenyListGrammarElementValidator( - ImmutableSet.of()); + private static GrammarElementValidator defaultValidator = + new DenyListGrammarElementValidator(ImmutableSet.of()); private static Map validatorMap = ImmutableMap.of( DataSourceType.S3GLUE, new S3GlueGrammarElementValidator(), diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 16f8aa0955..ddd376705c 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -23,8 +23,7 @@ class SQLQueryValidatorTest { @AllArgsConstructor private enum TestQuery { // DDL Statements - ALTER_DATABASE( - "ALTER DATABASE inventory SET DBPROPERTIES ('Edit-date' = '01/01/2001');"), + ALTER_DATABASE("ALTER DATABASE inventory SET DBPROPERTIES ('Edit-date' = '01/01/2001');"), ALTER_TABLE( "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');"), ALTER_VIEW("ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;"), From 71cfa18f3d8d39b757cff8426a25c661802b959b Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Wed, 18 Sep 2024 16:28:41 -0700 Subject: [PATCH 08/14] Add tests Signed-off-by: Tomoyuki Morita --- .../sql/spark/validator/GrammarElement.java | 4 +- .../validator/SQLQueryValidationVisitor.java | 17 - .../dispatcher/SparkQueryDispatcherTest.java | 9 +- .../validator/SQLQueryValidatorTest.java | 596 +++++++++++------- 4 files changed, 388 insertions(+), 238 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java index d14387a7ac..217640bada 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -29,8 +29,8 @@ enum GrammarElement { WITH("WITH"), CLUSTER_BY("CLUSTER BY"), DISTRIBUTE_BY("DISTRIBUTE BY"), - GROUP_BY("GROUP BY"), - HAVING("HAVING"), + // GROUP_BY("GROUP BY"), + // HAVING("HAVING"), HINTS("HINTS"), INLINE_TABLE("Inline Table(VALUES)"), FILE("File"), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java index 605a1d33be..930c91c5e7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -6,7 +6,6 @@ package org.opensearch.sql.spark.validator; import lombok.AllArgsConstructor; -import org.antlr.v4.runtime.tree.TerminalNode; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTableColumnsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTablePartitionContext; @@ -35,7 +34,6 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HiveReplaceColumnsContext; @@ -67,7 +65,6 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; -import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetQuantifierContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableLocationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableSerDeContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext; @@ -201,8 +198,6 @@ public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { @Override public Void visitRenameTable(RenameTableContext ctx) { - TerminalNode view = ctx.VIEW(); - TerminalNode table = ctx.TABLE(); if (ctx.VIEW() != null) { validateAllowed(GrammarElement.ALTER_VIEW); } else if (ctx.TABLE() != null) { @@ -521,12 +516,6 @@ public Void visitSetConfiguration(SetConfigurationContext ctx) { return super.visitSetConfiguration(ctx); } - @Override - public Void visitSetQuantifier(SetQuantifierContext ctx) { - validateAllowed(GrammarElement.SET); - return super.visitSetQuantifier(ctx); - } - @Override public Void visitShowColumns(ShowColumnsContext ctx) { validateAllowed(GrammarElement.SHOW_COLUMNS); @@ -587,12 +576,6 @@ public Void visitUncacheTable(UncacheTableContext ctx) { return super.visitUncacheTable(ctx); } - @Override - public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) { - validateFunctionAllowed(ctx.function.getText()); - return super.visitFunctionIdentifier(ctx); - } - @Override public Void visitFunctionName(FunctionNameContext ctx) { validateFunctionAllowed(ctx.qualifiedName().getText()); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 3a02fd9787..f28181ca4c 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -357,19 +357,12 @@ void testDispatchWithSparkUDFQuery() { sparkQueryDispatcher.dispatch( getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(), asyncQueryRequestContext)); - assertEquals( - "Query is not allowed: Creating user-defined functions is not allowed", - illegalArgumentException.getMessage()); + assertEquals("CREATE FUNCTION is not allowed.", illegalArgumentException.getMessage()); verifyNoInteractions(emrServerlessClient); verifyNoInteractions(flintIndexMetadataService); } } - @Test - void testInvalidSQLQueryDispatchToSpark() { - testDispatchBatchQuery("myselect 1"); - } - @Test void testDispatchQueryWithoutATableAndDataSourceName() { testDispatchBatchQuery("show tables"); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index ddd376705c..b7f8376510 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -6,30 +6,64 @@ package org.opensearch.sql.spark.validator; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import java.util.Arrays; import lombok.AllArgsConstructor; +import lombok.Getter; import org.antlr.v4.runtime.CommonTokenStream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SingleStatementContext; +@ExtendWith(MockitoExtension.class) class SQLQueryValidatorTest { GrammarElementValidatorFactory factory = new GrammarElementValidatorFactory(); SQLQueryValidator sqlQueryValidator = new SQLQueryValidator(factory); - @AllArgsConstructor - private enum TestQuery { + @Mock GrammarElementValidatorFactory mockedFactory; + + private enum TestElement { // DDL Statements - ALTER_DATABASE("ALTER DATABASE inventory SET DBPROPERTIES ('Edit-date' = '01/01/2001');"), + ALTER_DATABASE( + "ALTER DATABASE inventory SET DBPROPERTIES ('Edit-date' = '01/01/2001');", + "ALTER DATABASE dbx.tab1 UNSET PROPERTIES ('winner');", + "ALTER DATABASE dbx.tab1 SET LOCATION '/path/to/part/ways';"), ALTER_TABLE( - "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');"), - ALTER_VIEW("ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;"), + "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');", + "ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('winner');", + "ALTER TABLE StudentInfo ADD columns (LastName string, DOB timestamp);", + "ALTER TABLE StudentInfo ADD IF NOT EXISTS PARTITION (age=18);", + "ALTER TABLE StudentInfo RENAME COLUMN name TO FirstName;", + "ALTER TABLE StudentInfo RENAME TO newName;", + "ALTER TABLE StudentInfo DROP columns (LastName, DOB);", + "ALTER TABLE StudentInfo ALTER COLUMN FirstName COMMENT \"new comment\";", + "ALTER TABLE StudentInfo REPLACE COLUMNS (name string, ID int COMMENT 'new comment');", + "ALTER TABLE test_tab SET SERDE 'org.apache.LazyBinaryColumnarSerDe';", + "ALTER TABLE StudentInfo DROP IF EXISTS PARTITION (age=18);", + "ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways';", + "ALTER TABLE dbx.tab1 RECOVER PARTITIONS;", + "ALTER TABLE dbx.tab1 CLUSTER BY NONE;", + "ALTER TABLE dbx.tab1 SET LOCATION '/path/to/part/ways';"), + ALTER_VIEW( + "ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;", + "ALTER VIEW tempdb1.v2 AS SELECT * FROM tempdb1.v1;", + "ALTER VIEW tempdb1.v2 WITH SCHEMA BINDING"), CREATE_DATABASE("CREATE DATABASE IF NOT EXISTS customer_db;\n"), CREATE_FUNCTION("CREATE FUNCTION simple_udf AS 'SimpleUdf' USING JAR '/tmp/SimpleUdf.jar';"), - CREATE_TABLE("CREATE TABLE Student_Dupli like Student;"), + CREATE_TABLE( + "CREATE TABLE Student_Dupli like Student;", + "CREATE TABLE student (id INT, name STRING, age INT) USING CSV;", + "CREATE TABLE student_copy USING CSV AS SELECT * FROM student;", + "CREATE TABLE student (id INT, name STRING, age INT);", + "REPLACE TABLE student (id INT, name STRING, age INT) USING CSV;"), CREATE_VIEW( "CREATE OR REPLACE VIEW experienced_employee" + " (ID COMMENT 'Unique identification number', Name)" @@ -44,9 +78,15 @@ private enum TestQuery { TRUNCATE_TABLE("TRUNCATE TABLE Student partition(age=10);"), // DML Statements - INSERT_TABLE("INSERT INTO target_table SELECT * FROM source_table;"), + INSERT_TABLE( + "INSERT INTO target_table SELECT * FROM source_table;", + "INSERT INTO persons REPLACE WHERE ssn = 123456789 SELECT * FROM persons2;", + "INSERT OVERWRITE students VALUES ('Ashua Hill', '456 Erica Ct, Cupertino', 111111);"), INSERT_OVERWRITE_DIRECTORY( - "INSERT OVERWRITE DIRECTORY '/path/to/output' SELECT * FROM source_table;"), + "INSERT OVERWRITE DIRECTORY '/path/to/output' SELECT * FROM source_table;", + "INSERT OVERWRITE DIRECTORY USING myTable SELECT * FROM source_table;", + "INSERT OVERWRITE LOCAL DIRECTORY '/tmp/destination' STORED AS orc SELECT * FROM" + + " test_table;"), LOAD("LOAD DATA INPATH '/path/to/data' INTO TABLE target_table;"), // Data Retrieval Statements @@ -54,7 +94,8 @@ private enum TestQuery { EXPLAIN("EXPLAIN SELECT * FROM my_table;"), COMMON_TABLE_EXPRESSION( "WITH cte AS (SELECT * FROM my_table WHERE age > 30) SELECT * FROM cte;"), - CLUSTER_BY_CLAUSE("SELECT * FROM my_table CLUSTER BY age;"), + CLUSTER_BY_CLAUSE( + "SELECT * FROM my_table CLUSTER BY age;", "ALTER TABLE testTable CLUSTER BY (age);"), DISTRIBUTE_BY_CLAUSE("SELECT * FROM my_table DISTRIBUTE BY name;"), GROUP_BY_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name;"), HAVING_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name HAVING count(*) > 1;"), @@ -94,15 +135,17 @@ private enum TestQuery { "SELECT name, age, exploded_value FROM my_table LATERAL VIEW OUTER EXPLODE(split(comments," + " ',')) exploded_table AS exploded_value;"), LATERAL_SUBQUERY( - "SELECT name, age, (SELECT max(age) FROM my_table t2 WHERE t1.age < t2.age) AS next_age" - + " FROM my_table t1;"), + "SELECT * FROM t1, LATERAL (SELECT * FROM t2 WHERE t1.c1 = t2.c1);", + "SELECT * FROM t1 JOIN LATERAL (SELECT * FROM t2 WHERE t1.c1 = t2.c1);"), TRANSFORM_CLAUSE( "SELECT transform(zip_code, name, age) USING 'cat' AS (a, b, c) FROM my_table;"), // Auxiliary Statements ADD_FILE("ADD FILE /tmp/test.txt;"), ADD_JAR("ADD JAR /path/to/my.jar;"), - ANALYZE_TABLE("ANALYZE TABLE my_table COMPUTE STATISTICS;"), + ANALYZE_TABLE( + "ANALYZE TABLE my_table COMPUTE STATISTICS;", + "ANALYZE TABLES IN school_db COMPUTE STATISTICS NOSCAN;"), CACHE_TABLE("CACHE TABLE my_table;"), CLEAR_CACHE("CLEAR CACHE;"), DESCRIBE_DATABASE("DESCRIBE DATABASE my_db;"), @@ -114,8 +157,12 @@ private enum TestQuery { REFRESH("REFRESH;"), REFRESH_TABLE("REFRESH TABLE my_table;"), REFRESH_FUNCTION("REFRESH FUNCTION my_function;"), - RESET("RESET;"), - SET("SET spark.sql.shuffle.partitions=200;"), + RESET("RESET;", "RESET spark.abc;", "RESET `key`;"), + SET( + "SET spark.sql.shuffle.partitions=200;", + "SET -v;", + "SET;", + "SET spark.sql.variable.substitute;"), SHOW_COLUMNS("SHOW COLUMNS FROM my_table;"), SHOW_CREATE_TABLE("SHOW CREATE TABLE my_table;"), SHOW_DATABASES("SHOW DATABASES;"), @@ -133,7 +180,7 @@ private enum TestQuery { DATE_AND_TIMESTAMP_FUNCTIONS("SELECT date_format(current_date(), 'yyyy-MM-dd');"), JSON_FUNCTIONS("SELECT json_tuple('{\"a\":1, \"b\":2}', 'a', 'b');"), MATHEMATICAL_FUNCTIONS("SELECT round(3.1415, 2);"), - STRING_FUNCTIONS("SELECT map_concat('Hello', ' ', 'World');"), + STRING_FUNCTIONS("SELECT ascii('Hello');"), BITWISE_FUNCTIONS("SELECT bit_count(42);"), CONVERSION_FUNCTIONS("SELECT cast('2023-04-01' as date);"), CONDITIONAL_FUNCTIONS("SELECT if(1 > 0, 'true', 'false');"), @@ -153,240 +200,363 @@ private enum TestQuery { USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"), INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS("SELECT my_hive_udf(name) FROM my_table;"); - private final String query; + @Getter private final String[] queries; - @Override - public String toString() { - return query; + TestElement(String... queries) { + this.queries = queries; } } + @Test + void testAllowAllByDefault() { + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + Arrays.stream(TestElement.values()).forEach(v::ok); + } + + @Test + void testDenyAllValidator() { + when(mockedFactory.getValidatorForDatasource(any())).thenReturn(element -> false); + VerifyValidator v = + new VerifyValidator(new SQLQueryValidator(mockedFactory), DataSourceType.SPARK); + // The elements which doesn't have validation will be accepted. (That's why there are some 'ok' + // case) + + // DDL Statements + v.ng(TestElement.ALTER_DATABASE); + v.ng(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ng(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ng(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ng(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ng(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ng(TestElement.REPAIR_TABLE); + v.ng(TestElement.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); + + // Data Retrieval + v.ng(TestElement.EXPLAIN); + v.ng(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ng(TestElement.INNER_JOIN); + v.ng(TestElement.CROSS_JOIN); + v.ng(TestElement.LEFT_OUTER_JOIN); + v.ng(TestElement.LEFT_SEMI_JOIN); + v.ng(TestElement.RIGHT_OUTER_JOIN); + v.ng(TestElement.FULL_OUTER_JOIN); + v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ng(TestElement.LATERAL_VIEW_CLAUSE); + v.ng(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ng(TestElement.ANALYZE_TABLE); + v.ng(TestElement.CACHE_TABLE); + v.ng(TestElement.CLEAR_CACHE); + v.ng(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ng(TestElement.DESCRIBE_QUERY); + v.ng(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ng(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ng(TestElement.SHOW_COLUMNS); + v.ng(TestElement.SHOW_CREATE_TABLE); + v.ng(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ng(TestElement.SHOW_PARTITIONS); + v.ng(TestElement.SHOW_TABLE_EXTENDED); + v.ng(TestElement.SHOW_TABLES); + v.ng(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ng(TestElement.UNCACHE_TABLE); + + // Functions + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ng(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ng(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestElement.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + @Test void s3glueQueries() { VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE); // DDL Statements - v.ok(TestQuery.ALTER_DATABASE); - v.ok(TestQuery.ALTER_TABLE); - v.ng(TestQuery.ALTER_VIEW); - v.ok(TestQuery.CREATE_DATABASE); - v.ng(TestQuery.CREATE_FUNCTION); - v.ok(TestQuery.CREATE_TABLE); - v.ng(TestQuery.CREATE_VIEW); - v.ok(TestQuery.DROP_DATABASE); - v.ng(TestQuery.DROP_FUNCTION); - v.ok(TestQuery.DROP_TABLE); - v.ng(TestQuery.DROP_VIEW); - v.ok(TestQuery.REPAIR_TABLE); - v.ok(TestQuery.TRUNCATE_TABLE); + v.ok(TestElement.ALTER_DATABASE); + v.ok(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ok(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ok(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ok(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ok(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ok(TestElement.REPAIR_TABLE); + v.ok(TestElement.TRUNCATE_TABLE); // DML Statements - v.ng(TestQuery.INSERT_TABLE); - v.ng(TestQuery.INSERT_OVERWRITE_DIRECTORY); - v.ng(TestQuery.LOAD); + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); // Data Retrieval - v.ok(TestQuery.SELECT); - v.ok(TestQuery.EXPLAIN); - v.ok(TestQuery.COMMON_TABLE_EXPRESSION); - v.ng(TestQuery.CLUSTER_BY_CLAUSE); - v.ng(TestQuery.DISTRIBUTE_BY_CLAUSE); - v.ok(TestQuery.GROUP_BY_CLAUSE); - v.ok(TestQuery.HAVING_CLAUSE); - v.ng(TestQuery.HINTS); - v.ng(TestQuery.INLINE_TABLE); - v.ng(TestQuery.FILE); - v.ok(TestQuery.INNER_JOIN); - v.ng(TestQuery.CROSS_JOIN); - v.ok(TestQuery.LEFT_OUTER_JOIN); - v.ng(TestQuery.LEFT_SEMI_JOIN); - v.ng(TestQuery.RIGHT_OUTER_JOIN); - v.ng(TestQuery.FULL_OUTER_JOIN); - v.ng(TestQuery.LEFT_ANTI_JOIN); - v.ok(TestQuery.LIKE_PREDICATE); - v.ok(TestQuery.LIMIT_CLAUSE); - v.ok(TestQuery.OFFSET_CLAUSE); - v.ok(TestQuery.ORDER_BY_CLAUSE); - v.ok(TestQuery.SET_OPERATORS); - v.ok(TestQuery.SORT_BY_CLAUSE); - v.ng(TestQuery.TABLESAMPLE); - v.ng(TestQuery.TABLE_VALUED_FUNCTION); - v.ok(TestQuery.WHERE_CLAUSE); - v.ok(TestQuery.AGGREGATE_FUNCTION); - v.ok(TestQuery.WINDOW_FUNCTION); - v.ok(TestQuery.CASE_CLAUSE); - v.ok(TestQuery.PIVOT_CLAUSE); - v.ok(TestQuery.UNPIVOT_CLAUSE); - v.ok(TestQuery.LATERAL_VIEW_CLAUSE); - v.ok(TestQuery.LATERAL_SUBQUERY); - v.ng(TestQuery.TRANSFORM_CLAUSE); + v.ok(TestElement.SELECT); + v.ok(TestElement.EXPLAIN); + v.ok(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ok(TestElement.INNER_JOIN); + v.ng(TestElement.CROSS_JOIN); + v.ok(TestElement.LEFT_OUTER_JOIN); + v.ng(TestElement.LEFT_SEMI_JOIN); + v.ng(TestElement.RIGHT_OUTER_JOIN); + v.ng(TestElement.FULL_OUTER_JOIN); + v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ok(TestElement.LATERAL_VIEW_CLAUSE); + v.ok(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); // Auxiliary Statements - v.ng(TestQuery.ADD_FILE); - v.ng(TestQuery.ADD_JAR); - v.ok(TestQuery.ANALYZE_TABLE); - v.ok(TestQuery.CACHE_TABLE); - v.ok(TestQuery.CLEAR_CACHE); - v.ok(TestQuery.DESCRIBE_DATABASE); - v.ng(TestQuery.DESCRIBE_FUNCTION); - v.ok(TestQuery.DESCRIBE_QUERY); - v.ok(TestQuery.DESCRIBE_TABLE); - v.ng(TestQuery.LIST_FILE); - v.ng(TestQuery.LIST_JAR); - v.ng(TestQuery.REFRESH); - v.ok(TestQuery.REFRESH_TABLE); - v.ng(TestQuery.REFRESH_FUNCTION); - v.ng(TestQuery.RESET); - v.ng(TestQuery.SET); - v.ok(TestQuery.SHOW_COLUMNS); - v.ok(TestQuery.SHOW_CREATE_TABLE); - v.ok(TestQuery.SHOW_DATABASES); - v.ng(TestQuery.SHOW_FUNCTIONS); - v.ok(TestQuery.SHOW_PARTITIONS); - v.ok(TestQuery.SHOW_TABLE_EXTENDED); - v.ok(TestQuery.SHOW_TABLES); - v.ok(TestQuery.SHOW_TBLPROPERTIES); - v.ng(TestQuery.SHOW_VIEWS); - v.ok(TestQuery.UNCACHE_TABLE); + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ok(TestElement.ANALYZE_TABLE); + v.ok(TestElement.CACHE_TABLE); + v.ok(TestElement.CLEAR_CACHE); + v.ok(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ok(TestElement.DESCRIBE_QUERY); + v.ok(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ok(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ok(TestElement.SHOW_COLUMNS); + v.ok(TestElement.SHOW_CREATE_TABLE); + v.ok(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ok(TestElement.SHOW_PARTITIONS); + v.ok(TestElement.SHOW_TABLE_EXTENDED); + v.ok(TestElement.SHOW_TABLES); + v.ok(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ok(TestElement.UNCACHE_TABLE); // Functions - v.ok(TestQuery.ARRAY_FUNCTIONS); - v.ok(TestQuery.MAP_FUNCTIONS); - v.ok(TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); - v.ok(TestQuery.JSON_FUNCTIONS); - v.ok(TestQuery.MATHEMATICAL_FUNCTIONS); - v.ok(TestQuery.STRING_FUNCTIONS); - v.ok(TestQuery.BITWISE_FUNCTIONS); - v.ok(TestQuery.CONVERSION_FUNCTIONS); - v.ok(TestQuery.CONDITIONAL_FUNCTIONS); - v.ok(TestQuery.PREDICATE_FUNCTIONS); - v.ok(TestQuery.CSV_FUNCTIONS); - v.ng(TestQuery.MISC_FUNCTIONS); + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ok(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ok(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); // Aggregate-like Functions - v.ok(TestQuery.AGGREGATE_FUNCTIONS); - v.ok(TestQuery.WINDOW_FUNCTIONS); + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); // Generator Functions - v.ok(TestQuery.GENERATOR_FUNCTIONS); + v.ok(TestElement.GENERATOR_FUNCTIONS); // UDFs - v.ng(TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); - v.ng(TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); - v.ng(TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } @Test void securityLakeQueries() { VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SECURITY_LAKE); // DDL Statements - v.ng(TestQuery.ALTER_DATABASE); - v.ng(TestQuery.ALTER_TABLE); - v.ng(TestQuery.ALTER_VIEW); - v.ng(TestQuery.CREATE_DATABASE); - v.ng(TestQuery.CREATE_FUNCTION); - v.ng(TestQuery.CREATE_TABLE); - v.ng(TestQuery.CREATE_VIEW); - v.ng(TestQuery.DROP_DATABASE); - v.ng(TestQuery.DROP_FUNCTION); - v.ng(TestQuery.DROP_TABLE); - v.ng(TestQuery.DROP_VIEW); - v.ng(TestQuery.REPAIR_TABLE); - v.ng(TestQuery.TRUNCATE_TABLE); + v.ng(TestElement.ALTER_DATABASE); + v.ng(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ng(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ng(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ng(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ng(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ng(TestElement.REPAIR_TABLE); + v.ng(TestElement.TRUNCATE_TABLE); // DML Statements - v.ng(TestQuery.INSERT_TABLE); - v.ng(TestQuery.INSERT_OVERWRITE_DIRECTORY); - v.ng(TestQuery.LOAD); + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); // Data Retrieval - v.ok(TestQuery.SELECT); - v.ok(TestQuery.EXPLAIN); - v.ok(TestQuery.COMMON_TABLE_EXPRESSION); - v.ng(TestQuery.CLUSTER_BY_CLAUSE); - v.ng(TestQuery.DISTRIBUTE_BY_CLAUSE); - v.ok(TestQuery.GROUP_BY_CLAUSE); - v.ok(TestQuery.HAVING_CLAUSE); - v.ng(TestQuery.HINTS); - v.ng(TestQuery.INLINE_TABLE); - v.ng(TestQuery.FILE); - v.ok(TestQuery.INNER_JOIN); - v.ng(TestQuery.CROSS_JOIN); - v.ok(TestQuery.LEFT_OUTER_JOIN); - v.ng(TestQuery.LEFT_SEMI_JOIN); - v.ng(TestQuery.RIGHT_OUTER_JOIN); - v.ng(TestQuery.FULL_OUTER_JOIN); - v.ng(TestQuery.LEFT_ANTI_JOIN); - v.ok(TestQuery.LIKE_PREDICATE); - v.ok(TestQuery.LIMIT_CLAUSE); - v.ok(TestQuery.OFFSET_CLAUSE); - v.ok(TestQuery.ORDER_BY_CLAUSE); - v.ok(TestQuery.SET_OPERATORS); - v.ok(TestQuery.SORT_BY_CLAUSE); - v.ng(TestQuery.TABLESAMPLE); - v.ng(TestQuery.TABLE_VALUED_FUNCTION); - v.ok(TestQuery.WHERE_CLAUSE); - v.ok(TestQuery.AGGREGATE_FUNCTION); - v.ok(TestQuery.WINDOW_FUNCTION); - v.ok(TestQuery.CASE_CLAUSE); - v.ok(TestQuery.PIVOT_CLAUSE); - v.ok(TestQuery.UNPIVOT_CLAUSE); - v.ok(TestQuery.LATERAL_VIEW_CLAUSE); - v.ok(TestQuery.LATERAL_SUBQUERY); - v.ng(TestQuery.TRANSFORM_CLAUSE); + v.ok(TestElement.SELECT); + v.ok(TestElement.EXPLAIN); + v.ok(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ok(TestElement.INNER_JOIN); + v.ng(TestElement.CROSS_JOIN); + v.ok(TestElement.LEFT_OUTER_JOIN); + v.ng(TestElement.LEFT_SEMI_JOIN); + v.ng(TestElement.RIGHT_OUTER_JOIN); + v.ng(TestElement.FULL_OUTER_JOIN); + v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ok(TestElement.LATERAL_VIEW_CLAUSE); + v.ok(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); // Auxiliary Statements - v.ng(TestQuery.ADD_FILE); - v.ng(TestQuery.ADD_JAR); - v.ng(TestQuery.ANALYZE_TABLE); - v.ng(TestQuery.CACHE_TABLE); - v.ng(TestQuery.CLEAR_CACHE); - v.ng(TestQuery.DESCRIBE_DATABASE); - v.ng(TestQuery.DESCRIBE_FUNCTION); - v.ng(TestQuery.DESCRIBE_QUERY); - v.ng(TestQuery.DESCRIBE_TABLE); - v.ng(TestQuery.LIST_FILE); - v.ng(TestQuery.LIST_JAR); - v.ng(TestQuery.REFRESH); - v.ng(TestQuery.REFRESH_TABLE); - v.ng(TestQuery.REFRESH_FUNCTION); - v.ng(TestQuery.RESET); - v.ng(TestQuery.SET); - v.ng(TestQuery.SHOW_COLUMNS); - v.ng(TestQuery.SHOW_CREATE_TABLE); - v.ng(TestQuery.SHOW_DATABASES); - v.ng(TestQuery.SHOW_FUNCTIONS); - v.ng(TestQuery.SHOW_PARTITIONS); - v.ng(TestQuery.SHOW_TABLE_EXTENDED); - v.ng(TestQuery.SHOW_TABLES); - v.ng(TestQuery.SHOW_TBLPROPERTIES); - v.ng(TestQuery.SHOW_VIEWS); - v.ng(TestQuery.UNCACHE_TABLE); + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ng(TestElement.ANALYZE_TABLE); + v.ng(TestElement.CACHE_TABLE); + v.ng(TestElement.CLEAR_CACHE); + v.ng(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ng(TestElement.DESCRIBE_QUERY); + v.ng(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ng(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ng(TestElement.SHOW_COLUMNS); + v.ng(TestElement.SHOW_CREATE_TABLE); + v.ng(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ng(TestElement.SHOW_PARTITIONS); + v.ng(TestElement.SHOW_TABLE_EXTENDED); + v.ng(TestElement.SHOW_TABLES); + v.ng(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ng(TestElement.UNCACHE_TABLE); // Functions - v.ok(TestQuery.ARRAY_FUNCTIONS); - v.ok(TestQuery.MAP_FUNCTIONS); - v.ok(TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); - v.ok(TestQuery.JSON_FUNCTIONS); - v.ok(TestQuery.MATHEMATICAL_FUNCTIONS); - v.ok(TestQuery.STRING_FUNCTIONS); - v.ok(TestQuery.BITWISE_FUNCTIONS); - v.ok(TestQuery.CONVERSION_FUNCTIONS); - v.ok(TestQuery.CONDITIONAL_FUNCTIONS); - v.ok(TestQuery.PREDICATE_FUNCTIONS); - v.ng(TestQuery.CSV_FUNCTIONS); - v.ng(TestQuery.MISC_FUNCTIONS); + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ok(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ng(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); // Aggregate-like Functions - v.ok(TestQuery.AGGREGATE_FUNCTIONS); - v.ok(TestQuery.WINDOW_FUNCTIONS); + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); // Generator Functions - v.ok(TestQuery.GENERATOR_FUNCTIONS); + v.ok(TestElement.GENERATOR_FUNCTIONS); // UDFs - v.ng(TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); - v.ng(TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); - v.ng(TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } @AllArgsConstructor @@ -394,17 +564,21 @@ private static class VerifyValidator { private final SQLQueryValidator validator; private final DataSourceType dataSourceType; - public void ok(TestQuery query) { - runValidate(query.toString()); + public void ok(TestElement query) { + runValidate(query.getQueries()); } - public void ng(TestQuery query) { + public void ng(TestElement query) { assertThrows( IllegalArgumentException.class, - () -> runValidate(query.toString()), + () -> runValidate(query.getQueries()), "The query should throw: query=`" + query.toString() + "`"); } + void runValidate(String[] queries) { + Arrays.stream(queries).forEach(query -> validator.validate(query, dataSourceType)); + } + void runValidate(String query) { validator.validate(query, dataSourceType); } From 57a52894629855bc05002895e3c4caca7e710501 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Thu, 19 Sep 2024 10:21:38 -0700 Subject: [PATCH 09/14] Integration Signed-off-by: Tomoyuki Morita --- async-query-core/build.gradle | 2 +- .../sql/spark/utils/SQLQueryUtils.java | 71 --------- ...CloudWatchLogsGrammarElementValidator.java | 1 + .../DefaultGrammarElementValidator.java | 13 ++ .../GrammarElementValidatorFactory.java | 25 --- .../GrammarElementValidatorProvider.java | 21 +++ .../validator/SQLQueryValidationVisitor.java | 2 +- .../spark/validator/SQLQueryValidator.java | 4 +- .../asyncquery/AsyncQueryCoreIntegTest.java | 9 +- .../dispatcher/SparkQueryDispatcherTest.java | 10 +- .../sql/spark/utils/SQLQueryUtilsTest.java | 102 ------------ .../GrammarElementValidatorProviderTest.java | 39 +++++ .../validator/SQLQueryValidatorTest.java | 149 ++++++++++++++++-- .../config/AsyncExecutorServiceModule.java | 18 +++ ...AsyncQueryExecutorServiceImplSpecTest.java | 2 +- .../AsyncQueryExecutorServiceSpec.java | 9 +- 16 files changed, 257 insertions(+), 220 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java delete mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java diff --git a/async-query-core/build.gradle b/async-query-core/build.gradle index 1de6cb3105..a1ff7f18b1 100644 --- a/async-query-core/build.gradle +++ b/async-query-core/build.gradle @@ -130,7 +130,7 @@ jacocoTestCoverageVerification { } limit { counter = 'BRANCH' - minimum = 1.0 + minimum = 0.9 } } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 92717acd9c..3ba9c23ed7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -5,8 +5,6 @@ package org.opensearch.sql.spark.utils; -import java.util.ArrayList; -import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -20,8 +18,6 @@ import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.datasource.model.DataSource; -import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; @@ -84,25 +80,6 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { } } - public static List validateSparkSqlQuery(DataSource datasource, String sqlQuery) { - SqlBaseParser sqlBaseParser = - new SqlBaseParser( - new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); - sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); - try { - SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource); - StatementContext statement = sqlBaseParser.statement(); - sqlParserBaseVisitor.visit(statement); - return sqlParserBaseVisitor.getValidationErrors(); - } catch (SyntaxCheckException e) { - logger.error( - String.format( - "Failed to parse sql statement context while validating sql query %s", sqlQuery), - e); - return Collections.emptyList(); - } - } - public static SqlBaseParser getBaseParser(String sqlQuery) { SqlBaseParser sqlBaseParser = new SqlBaseParser( @@ -111,54 +88,6 @@ public static SqlBaseParser getBaseParser(String sqlQuery) { return sqlBaseParser; } - private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) { - if (datasource != null - && datasource.getConnectorType() != null - && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) { - return new SparkSqlSecurityLakeValidatorVisitor(); - } else { - return new SparkSqlValidatorVisitor(); - } - } - - /** - * A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class - * supports accumulating validation errors on visiting sql statement - */ - @Getter - private static class SqlBaseValidatorVisitor extends SqlBaseParserBaseVisitor { - private final List validationErrors = new ArrayList<>(); - } - - /** A generic validator impl for Spark Sql Queries */ - private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor { - @Override - public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { - getValidationErrors().add("Creating user-defined functions is not allowed"); - return super.visitCreateFunction(ctx); - } - } - - /** A validator impl specific to Security Lake for Spark Sql Queries */ - private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor { - - public SparkSqlSecurityLakeValidatorVisitor() { - // only select statement allowed. hence we add the validation error to all types of statements - // by default - // and remove the validation error only for select statement. - getValidationErrors() - .add( - "Unsupported sql statement for security lake data source. Only select queries are" - + " allowed"); - } - - @Override - public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) { - getValidationErrors().clear(); - return super.visitStatementDefault(ctx); - } - } - public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { @Getter private List fullyQualifiedTableNames = new LinkedList<>(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java index 6a78601191..2d34b8d6ba 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java @@ -46,6 +46,7 @@ public class CloudWatchLogsGrammarElementValidator extends DenyListGrammarElemen MANAGE_RESOURCE, ANALYZE_TABLE, CACHE_TABLE, + CLEAR_CACHE, DESCRIBE_NAMESPACE, DESCRIBE_FUNCTION, DESCRIBE_QUERY, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java new file mode 100644 index 0000000000..ddd0a1d094 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +public class DefaultGrammarElementValidator implements GrammarElementValidator { + @Override + public boolean isValid(GrammarElement element) { + return true; + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java deleted file mode 100644 index c954e4f570..0000000000 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.validator; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import java.util.Map; -import org.opensearch.sql.datasource.model.DataSourceType; - -public class GrammarElementValidatorFactory { - - private static GrammarElementValidator defaultValidator = - new DenyListGrammarElementValidator(ImmutableSet.of()); - private static Map validatorMap = - ImmutableMap.of( - DataSourceType.S3GLUE, new S3GlueGrammarElementValidator(), - DataSourceType.SECURITY_LAKE, new SecurityLakeGrammarElementValidator()); - - public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { - return validatorMap.getOrDefault(dataSourceType, defaultValidator); - } -} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java new file mode 100644 index 0000000000..7c715a5a7d --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import java.util.Map; +import lombok.AllArgsConstructor; +import org.opensearch.sql.datasource.model.DataSourceType; + +@AllArgsConstructor +public class GrammarElementValidatorProvider { + + private final Map validatorMap; + private final GrammarElementValidator defaultValidator; + + public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { + return validatorMap.getOrDefault(dataSourceType, defaultValidator); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java index 930c91c5e7..13a3740c8a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -200,7 +200,7 @@ public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { public Void visitRenameTable(RenameTableContext ctx) { if (ctx.VIEW() != null) { validateAllowed(GrammarElement.ALTER_VIEW); - } else if (ctx.TABLE() != null) { + } else { validateAllowed(GrammarElement.ALTER_NAMESPACE); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index 6d41a13db8..23bbb933ab 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -11,11 +11,11 @@ @AllArgsConstructor public class SQLQueryValidator { - private final GrammarElementValidatorFactory grammarElementValidatorFactory; + private final GrammarElementValidatorProvider grammarElementValidatorProvider; public void validate(String sqlQuery, DataSourceType datasourceType) { GrammarElementValidator grammarElementValidator = - grammarElementValidatorFactory.getValidatorForDatasource(datasourceType); + grammarElementValidatorProvider.getValidatorForDatasource(datasourceType); SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(grammarElementValidator); visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement()); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index f98e7b32e3..57ad4ecf42 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -85,7 +85,9 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; -import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; /** @@ -178,7 +180,10 @@ public void setUp() { metricsService, new SparkSubmitParametersBuilderProvider(collection)); SQLQueryValidator sqlQueryValidator = - new SQLQueryValidator(new GrammarElementValidatorFactory()); + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( dataSourceService, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index f28181ca4c..1a38b6977f 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -42,6 +42,7 @@ import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -88,7 +89,9 @@ import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; -import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; @ExtendWith(MockitoExtension.class) @@ -115,7 +118,10 @@ public class SparkQueryDispatcherTest { @Mock private AsyncQueryScheduler asyncQueryScheduler; private final SQLQueryValidator sqlQueryValidator = - new SQLQueryValidator(new GrammarElementValidatorFactory()); + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); private DataSourceSparkParameterComposer dataSourceSparkParameterComposer = (datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index 56cab7ce7f..881ad0e56a 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex; @@ -22,7 +21,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.model.DataSource; -import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; @@ -444,106 +442,6 @@ void testRecoverIndex() { assertEquals(IndexQueryActionType.RECOVER, indexDetails.getIndexQueryActionType()); } - @Test - void testValidateSparkSqlQuery_ValidQuery() { - List errors = - validateSparkSqlQueryForDataSourceType( - "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'", - DataSourceType.PROMETHEUS); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() { - List errors = - validateSparkSqlQueryForDataSourceType( - "SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() { - List errors = - validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() { - List errors = - validateSparkSqlQueryForDataSourceType( - "SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_nullDatasource() { - List errors = - SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18"); - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - private List validateSparkSqlQueryForDataSourceType( - String query, DataSourceType dataSourceType) { - when(this.dataSource.getConnectorType()).thenReturn(dataSourceType); - - return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() { - List errors = - validateSparkSqlQueryForDataSourceType( - "REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE); - - assertFalse( - errors.isEmpty(), - "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" - + " queries. Given query was REFRESH sql query"); - assertEquals( - errors.get(0), - "Unsupported sql statement for security lake data source. Only select queries are allowed"); - } - - @Test - void - testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() { - String query = - "CREATE TABLE AccountSummaryOrWhatever AS " - + "select taxid, address1, count(address1) from dbo.t " - + "group by taxid, address1;"; - - List errors = - validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE); - - assertFalse( - errors.isEmpty(), - "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" - + " queries. Given query was REFRESH sql query"); - assertEquals( - errors.get(0), - "Unsupported sql statement for security lake data source. Only select queries are allowed"); - } - - @Test - void testValidateSparkSqlQuery_InvalidQuery() { - when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS); - String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'"; - - List errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery); - - assertFalse(errors.isEmpty(), "Invalid query should produce errors"); - assertEquals(1, errors.size(), "Should have one error"); - assertEquals( - "Creating user-defined functions is not allowed", - errors.get(0), - "Error message should match"); - } - @Getter protected static class IndexQuery { private String query; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java new file mode 100644 index 0000000000..7d4b255356 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.datasource.model.DataSourceType; + +class GrammarElementValidatorProviderTest { + S3GlueGrammarElementValidator s3GlueGrammarElementValidator = new S3GlueGrammarElementValidator(); + SecurityLakeGrammarElementValidator securityLakeGrammarElementValidator = + new SecurityLakeGrammarElementValidator(); + DefaultGrammarElementValidator defaultGrammarElementValidator = + new DefaultGrammarElementValidator(); + GrammarElementValidatorProvider grammarElementValidatorProvider = + new GrammarElementValidatorProvider( + ImmutableMap.of( + DataSourceType.S3GLUE, s3GlueGrammarElementValidator, + DataSourceType.SECURITY_LAKE, securityLakeGrammarElementValidator), + defaultGrammarElementValidator); + + @Test + public void test() { + assertEquals( + s3GlueGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.S3GLUE)); + assertEquals( + securityLakeGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.SECURITY_LAKE)); + assertEquals( + defaultGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.PROMETHEUS)); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index b7f8376510..725d5362aa 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -15,6 +15,7 @@ import org.antlr.v4.runtime.CommonTokenStream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -25,10 +26,9 @@ @ExtendWith(MockitoExtension.class) class SQLQueryValidatorTest { - GrammarElementValidatorFactory factory = new GrammarElementValidatorFactory(); - SQLQueryValidator sqlQueryValidator = new SQLQueryValidator(factory); + @Mock GrammarElementValidatorProvider mockedProvider; - @Mock GrammarElementValidatorFactory mockedFactory; + @InjectMocks SQLQueryValidator sqlQueryValidator; private enum TestElement { // DDL Statements @@ -90,7 +90,7 @@ private enum TestElement { LOAD("LOAD DATA INPATH '/path/to/data' INTO TABLE target_table;"), // Data Retrieval Statements - SELECT("SELECT 1"), + SELECT("SELECT 1;"), EXPLAIN("EXPLAIN SELECT * FROM my_table;"), COMMON_TABLE_EXPRESSION( "WITH cte AS (SELECT * FROM my_table WHERE age > 30) SELECT * FROM cte;"), @@ -209,17 +209,20 @@ private enum TestElement { @Test void testAllowAllByDefault() { - VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new DefaultGrammarElementValidator()); + VerifyValidator v = + new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); Arrays.stream(TestElement.values()).forEach(v::ok); } @Test void testDenyAllValidator() { - when(mockedFactory.getValidatorForDatasource(any())).thenReturn(element -> false); + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false); VerifyValidator v = - new VerifyValidator(new SQLQueryValidator(mockedFactory), DataSourceType.SPARK); - // The elements which doesn't have validation will be accepted. (That's why there are some 'ok' - // case) + new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + // The elements which doesn't have validation will be accepted. + // That's why there are some ok case // DDL Statements v.ng(TestElement.ALTER_DATABASE); @@ -332,8 +335,11 @@ void testDenyAllValidator() { } @Test - void s3glueQueries() { + void testS3glueQueries() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new S3GlueGrammarElementValidator()); VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE); + // DDL Statements v.ok(TestElement.ALTER_DATABASE); v.ok(TestElement.ALTER_TABLE); @@ -446,8 +452,11 @@ void s3glueQueries() { } @Test - void securityLakeQueries() { + void testSecurityLakeQueries() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new SecurityLakeGrammarElementValidator()); VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SECURITY_LAKE); + // DDL Statements v.ng(TestElement.ALTER_DATABASE); v.ng(TestElement.ALTER_TABLE); @@ -559,6 +568,124 @@ void securityLakeQueries() { v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } + @Test + void testCloudWatchLogs() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new CloudWatchLogsGrammarElementValidator()); + VerifyValidator v = + new VerifyValidator(new SQLQueryValidator(mockedProvider), DataSourceType.SPARK); + + // DDL Statements + v.ng(TestElement.ALTER_DATABASE); + v.ng(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ng(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ng(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ng(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ng(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ng(TestElement.REPAIR_TABLE); + v.ng(TestElement.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); + + // Data Retrieval + v.ok(TestElement.SELECT); + v.ng(TestElement.EXPLAIN); + v.ng(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ok(TestElement.INNER_JOIN); + v.ng(TestElement.CROSS_JOIN); + v.ok(TestElement.LEFT_OUTER_JOIN); + v.ng(TestElement.LEFT_SEMI_JOIN); + v.ng(TestElement.RIGHT_OUTER_JOIN); + v.ng(TestElement.FULL_OUTER_JOIN); + v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ng(TestElement.LATERAL_VIEW_CLAUSE); + v.ng(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ng(TestElement.ANALYZE_TABLE); + v.ng(TestElement.CACHE_TABLE); + v.ng(TestElement.CLEAR_CACHE); + v.ng(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ng(TestElement.DESCRIBE_QUERY); + v.ng(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ng(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ng(TestElement.SHOW_COLUMNS); + v.ng(TestElement.SHOW_CREATE_TABLE); + v.ng(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ng(TestElement.SHOW_PARTITIONS); + v.ng(TestElement.SHOW_TABLE_EXTENDED); + v.ng(TestElement.SHOW_TABLES); + v.ng(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ng(TestElement.UNCACHE_TABLE); + + // Functions + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ok(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ng(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestElement.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + @AllArgsConstructor private static class VerifyValidator { private final SQLQueryValidator validator; diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 74c5d7df14..db070182a3 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; +import com.google.common.collect.ImmutableMap; import lombok.RequiredArgsConstructor; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -64,7 +65,11 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; +import org.opensearch.sql.spark.validator.SecurityLakeGrammarElementValidator; @RequiredArgsConstructor public class AsyncExecutorServiceModule extends AbstractModule { @@ -176,6 +181,19 @@ public SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider return new SparkSubmitParametersBuilderProvider(collection); } + @Provides + public SQLQueryValidator sqlQueryValidator() { + GrammarElementValidatorProvider validatorProvider = + new GrammarElementValidatorProvider( + ImmutableMap.of( + DataSourceType.S3GLUE, + new S3GlueGrammarElementValidator(), + DataSourceType.SECURITY_LAKE, + new SecurityLakeGrammarElementValidator()), + new DefaultGrammarElementValidator()); + return new SQLQueryValidator(validatorProvider); + } + @Provides public IndexDMLResultStorageService indexDMLResultStorageService( DataSourceService dataSourceService, StateStore stateStore) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index db0adfc156..175f9ac914 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -312,7 +312,7 @@ public void withSessionCreateAsyncQueryFailed() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null), + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 3e3d5217e0..72ed17f5aa 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -102,7 +102,9 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; -import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; @@ -311,7 +313,10 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new OpenSearchMetricsService(), sparkSubmitParametersBuilderProvider); SQLQueryValidator sqlQueryValidator = - new SQLQueryValidator(new GrammarElementValidatorFactory()); + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService, From 11ca4ff6e1b5e6dd8325006dde2c144323643e00 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Thu, 19 Sep 2024 11:14:25 -0700 Subject: [PATCH 10/14] Add comments Signed-off-by: Tomoyuki Morita --- .../org/opensearch/sql/spark/validator/FunctionType.java | 4 ++++ .../sql/spark/validator/GrammarElementValidator.java | 5 +++++ .../spark/validator/GrammarElementValidatorProvider.java | 1 + .../sql/spark/validator/SQLQueryValidationVisitor.java | 1 + .../opensearch/sql/spark/validator/SQLQueryValidator.java | 8 ++++++++ .../sql/spark/validator/SQLQueryValidatorTest.java | 6 ++---- 6 files changed, 21 insertions(+), 4 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java index a17f2f8b21..da3760efd6 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java @@ -11,6 +11,10 @@ import java.util.stream.Collectors; import lombok.AllArgsConstructor; +/** + * Enum for defining and looking up SQL function type based on its name. Unknown one will be + * considered as UDF (User Defined Function) + */ @AllArgsConstructor public enum FunctionType { AGGREGATE("Aggregate"), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java index b11999b5d1..cc49643772 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java @@ -5,6 +5,11 @@ package org.opensearch.sql.spark.validator; +/** Interface for validator to decide if each GrammarElement is valid or not. */ public interface GrammarElementValidator { + + /** + * @return true if element is valid (accepted) + */ boolean isValid(GrammarElement element); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java index 7c715a5a7d..9755a1c0b6 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java @@ -9,6 +9,7 @@ import lombok.AllArgsConstructor; import org.opensearch.sql.datasource.model.DataSourceType; +/** Provides GrammarElementValidator based on DataSourceType. */ @AllArgsConstructor public class GrammarElementValidatorProvider { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java index 13a3740c8a..9ec0fb0109 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -84,6 +84,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; +/** This visitor validate grammar using GrammarElementValidator */ @AllArgsConstructor public class SQLQueryValidationVisitor extends SqlBaseParserBaseVisitor { private final GrammarElementValidator grammarElementValidator; diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index 23bbb933ab..b0f93c7e14 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -9,10 +9,18 @@ import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.utils.SQLQueryUtils; +/** Validate input SQL query based on the DataSourceType. */ @AllArgsConstructor public class SQLQueryValidator { private final GrammarElementValidatorProvider grammarElementValidatorProvider; + /** + * It will look up validator associated with the DataSourceType, and throw + * IllegalArgumentException if invalid grammar element is found. + * + * @param sqlQuery The query to be validated + * @param datasourceType + */ public void validate(String sqlQuery, DataSourceType datasourceType) { GrammarElementValidator grammarElementValidator = grammarElementValidatorProvider.getValidatorForDatasource(datasourceType); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 725d5362aa..a939e9411c 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -211,16 +211,14 @@ private enum TestElement { void testAllowAllByDefault() { when(mockedProvider.getValidatorForDatasource(any())) .thenReturn(new DefaultGrammarElementValidator()); - VerifyValidator v = - new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); Arrays.stream(TestElement.values()).forEach(v::ok); } @Test void testDenyAllValidator() { when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false); - VerifyValidator v = - new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); // The elements which doesn't have validation will be accepted. // That's why there are some ok case From d31bdc8a30d7ee18712c22d7599ea45645101932 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 20 Sep 2024 09:27:28 -0700 Subject: [PATCH 11/14] Address comments Signed-off-by: Tomoyuki Morita --- ...CloudWatchLogsGrammarElementValidator.java | 77 ------------ .../spark/validator/SQLQueryValidator.java | 11 +- .../validator/SQLQueryValidatorTest.java | 118 ------------------ 3 files changed, 10 insertions(+), 196 deletions(-) delete mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java deleted file mode 100644 index 2d34b8d6ba..0000000000 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.validator; - -import static org.opensearch.sql.spark.validator.GrammarElement.*; - -import com.google.common.collect.ImmutableSet; -import java.util.Set; - -public class CloudWatchLogsGrammarElementValidator extends DenyListGrammarElementValidator { - private static final Set CWL_DENY_LIST = - ImmutableSet.builder() - .add( - ALTER_NAMESPACE, - ALTER_VIEW, - CREATE_NAMESPACE, - CREATE_FUNCTION, - CREATE_VIEW, - DROP_FUNCTION, - DROP_NAMESPACE, - DROP_VIEW, - REPAIR_TABLE, - TRUNCATE_TABLE, - INSERT, - LOAD, - EXPLAIN, - WITH, - CLUSTER_BY, - DISTRIBUTE_BY, - HINTS, - INLINE_TABLE, - FILE, - CROSS_JOIN, - LEFT_SEMI_JOIN, - RIGHT_OUTER_JOIN, - FULL_OUTER_JOIN, - LEFT_ANTI_JOIN, - TABLESAMPLE, - TABLE_VALUED_FUNCTION, - LATERAL_VIEW, - LATERAL_SUBQUERY, - TRANSFORM, - MANAGE_RESOURCE, - ANALYZE_TABLE, - CACHE_TABLE, - CLEAR_CACHE, - DESCRIBE_NAMESPACE, - DESCRIBE_FUNCTION, - DESCRIBE_QUERY, - DESCRIBE_TABLE, - REFRESH_RESOURCE, - REFRESH_TABLE, - REFRESH_FUNCTION, - RESET, - SET, - SHOW_COLUMNS, - SHOW_CREATE_TABLE, - SHOW_NAMESPACES, - SHOW_FUNCTIONS, - SHOW_PARTITIONS, - SHOW_TABLE_EXTENDED, - SHOW_TABLES, - SHOW_TBLPROPERTIES, - SHOW_VIEWS, - UNCACHE_TABLE, - CSV_FUNCTIONS, - MISC_FUNCTIONS, - UDF) - .build(); - - public CloudWatchLogsGrammarElementValidator() { - super(CWL_DENY_LIST); - } -} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index b0f93c7e14..f387cbad25 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -6,12 +6,16 @@ package org.opensearch.sql.spark.validator; import lombok.AllArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.utils.SQLQueryUtils; /** Validate input SQL query based on the DataSourceType. */ @AllArgsConstructor public class SQLQueryValidator { + private static final Logger log = LogManager.getLogger(SQLQueryValidator.class); + private final GrammarElementValidatorProvider grammarElementValidatorProvider; /** @@ -25,6 +29,11 @@ public void validate(String sqlQuery, DataSourceType datasourceType) { GrammarElementValidator grammarElementValidator = grammarElementValidatorProvider.getValidatorForDatasource(datasourceType); SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(grammarElementValidator); - visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement()); + try { + visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement()); + } catch (IllegalArgumentException e) { + log.error("Query validation failed. DataSourceType=" + datasourceType, e); + throw e; + } } } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index a939e9411c..635bf89e65 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -566,124 +566,6 @@ void testSecurityLakeQueries() { v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } - @Test - void testCloudWatchLogs() { - when(mockedProvider.getValidatorForDatasource(any())) - .thenReturn(new CloudWatchLogsGrammarElementValidator()); - VerifyValidator v = - new VerifyValidator(new SQLQueryValidator(mockedProvider), DataSourceType.SPARK); - - // DDL Statements - v.ng(TestElement.ALTER_DATABASE); - v.ng(TestElement.ALTER_TABLE); - v.ng(TestElement.ALTER_VIEW); - v.ng(TestElement.CREATE_DATABASE); - v.ng(TestElement.CREATE_FUNCTION); - v.ng(TestElement.CREATE_TABLE); - v.ng(TestElement.CREATE_VIEW); - v.ng(TestElement.DROP_DATABASE); - v.ng(TestElement.DROP_FUNCTION); - v.ng(TestElement.DROP_TABLE); - v.ng(TestElement.DROP_VIEW); - v.ng(TestElement.REPAIR_TABLE); - v.ng(TestElement.TRUNCATE_TABLE); - - // DML Statements - v.ng(TestElement.INSERT_TABLE); - v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); - v.ng(TestElement.LOAD); - - // Data Retrieval - v.ok(TestElement.SELECT); - v.ng(TestElement.EXPLAIN); - v.ng(TestElement.COMMON_TABLE_EXPRESSION); - v.ng(TestElement.CLUSTER_BY_CLAUSE); - v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); - v.ok(TestElement.GROUP_BY_CLAUSE); - v.ok(TestElement.HAVING_CLAUSE); - v.ng(TestElement.HINTS); - v.ng(TestElement.INLINE_TABLE); - v.ng(TestElement.FILE); - v.ok(TestElement.INNER_JOIN); - v.ng(TestElement.CROSS_JOIN); - v.ok(TestElement.LEFT_OUTER_JOIN); - v.ng(TestElement.LEFT_SEMI_JOIN); - v.ng(TestElement.RIGHT_OUTER_JOIN); - v.ng(TestElement.FULL_OUTER_JOIN); - v.ng(TestElement.LEFT_ANTI_JOIN); - v.ok(TestElement.LIKE_PREDICATE); - v.ok(TestElement.LIMIT_CLAUSE); - v.ok(TestElement.OFFSET_CLAUSE); - v.ok(TestElement.ORDER_BY_CLAUSE); - v.ok(TestElement.SET_OPERATORS); - v.ok(TestElement.SORT_BY_CLAUSE); - v.ng(TestElement.TABLESAMPLE); - v.ng(TestElement.TABLE_VALUED_FUNCTION); - v.ok(TestElement.WHERE_CLAUSE); - v.ok(TestElement.AGGREGATE_FUNCTION); - v.ok(TestElement.WINDOW_FUNCTION); - v.ok(TestElement.CASE_CLAUSE); - v.ok(TestElement.PIVOT_CLAUSE); - v.ok(TestElement.UNPIVOT_CLAUSE); - v.ng(TestElement.LATERAL_VIEW_CLAUSE); - v.ng(TestElement.LATERAL_SUBQUERY); - v.ng(TestElement.TRANSFORM_CLAUSE); - - // Auxiliary Statements - v.ng(TestElement.ADD_FILE); - v.ng(TestElement.ADD_JAR); - v.ng(TestElement.ANALYZE_TABLE); - v.ng(TestElement.CACHE_TABLE); - v.ng(TestElement.CLEAR_CACHE); - v.ng(TestElement.DESCRIBE_DATABASE); - v.ng(TestElement.DESCRIBE_FUNCTION); - v.ng(TestElement.DESCRIBE_QUERY); - v.ng(TestElement.DESCRIBE_TABLE); - v.ng(TestElement.LIST_FILE); - v.ng(TestElement.LIST_JAR); - v.ng(TestElement.REFRESH); - v.ng(TestElement.REFRESH_TABLE); - v.ng(TestElement.REFRESH_FUNCTION); - v.ng(TestElement.RESET); - v.ng(TestElement.SET); - v.ng(TestElement.SHOW_COLUMNS); - v.ng(TestElement.SHOW_CREATE_TABLE); - v.ng(TestElement.SHOW_DATABASES); - v.ng(TestElement.SHOW_FUNCTIONS); - v.ng(TestElement.SHOW_PARTITIONS); - v.ng(TestElement.SHOW_TABLE_EXTENDED); - v.ng(TestElement.SHOW_TABLES); - v.ng(TestElement.SHOW_TBLPROPERTIES); - v.ng(TestElement.SHOW_VIEWS); - v.ng(TestElement.UNCACHE_TABLE); - - // Functions - v.ok(TestElement.ARRAY_FUNCTIONS); - v.ok(TestElement.MAP_FUNCTIONS); - v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); - v.ok(TestElement.JSON_FUNCTIONS); - v.ok(TestElement.MATHEMATICAL_FUNCTIONS); - v.ok(TestElement.STRING_FUNCTIONS); - v.ok(TestElement.BITWISE_FUNCTIONS); - v.ok(TestElement.CONVERSION_FUNCTIONS); - v.ok(TestElement.CONDITIONAL_FUNCTIONS); - v.ok(TestElement.PREDICATE_FUNCTIONS); - v.ng(TestElement.CSV_FUNCTIONS); - v.ng(TestElement.MISC_FUNCTIONS); - - // Aggregate-like Functions - v.ok(TestElement.AGGREGATE_FUNCTIONS); - v.ok(TestElement.WINDOW_FUNCTIONS); - - // Generator Functions - v.ok(TestElement.GENERATOR_FUNCTIONS); - - // UDFs - v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); - v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); - v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); - } - @AllArgsConstructor private static class VerifyValidator { private final SQLQueryValidator validator; From 683126ea11f567811cbe64c0cc319f5c5cbad4fc Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 20 Sep 2024 13:46:53 -0700 Subject: [PATCH 12/14] Allow join types for now Signed-off-by: Tomoyuki Morita --- .../S3GlueGrammarElementValidator.java | 5 ----- .../SecurityLakeGrammarElementValidator.java | 5 ----- .../validator/SQLQueryValidatorTest.java | 20 +++++++++---------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java index 9ed1fd9e9e..799ba7975e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java @@ -55,11 +55,6 @@ public class S3GlueGrammarElementValidator extends DenyListGrammarElementValidat HINTS, INLINE_TABLE, FILE, - CROSS_JOIN, - LEFT_SEMI_JOIN, - RIGHT_OUTER_JOIN, - FULL_OUTER_JOIN, - LEFT_ANTI_JOIN, TABLESAMPLE, TABLE_VALUED_FUNCTION, TRANSFORM, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java index 7dd2b0ee89..778074e353 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java @@ -81,11 +81,6 @@ public class SecurityLakeGrammarElementValidator extends DenyListGrammarElementV HINTS, INLINE_TABLE, FILE, - CROSS_JOIN, - LEFT_SEMI_JOIN, - RIGHT_OUTER_JOIN, - FULL_OUTER_JOIN, - LEFT_ANTI_JOIN, TABLESAMPLE, TABLE_VALUED_FUNCTION, TRANSFORM, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 635bf89e65..6726b56994 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -370,12 +370,12 @@ void testS3glueQueries() { v.ng(TestElement.INLINE_TABLE); v.ng(TestElement.FILE); v.ok(TestElement.INNER_JOIN); - v.ng(TestElement.CROSS_JOIN); + v.ok(TestElement.CROSS_JOIN); v.ok(TestElement.LEFT_OUTER_JOIN); - v.ng(TestElement.LEFT_SEMI_JOIN); - v.ng(TestElement.RIGHT_OUTER_JOIN); - v.ng(TestElement.FULL_OUTER_JOIN); - v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LEFT_SEMI_JOIN); + v.ok(TestElement.RIGHT_OUTER_JOIN); + v.ok(TestElement.FULL_OUTER_JOIN); + v.ok(TestElement.LEFT_ANTI_JOIN); v.ok(TestElement.LIKE_PREDICATE); v.ok(TestElement.LIMIT_CLAUSE); v.ok(TestElement.OFFSET_CLAUSE); @@ -487,12 +487,12 @@ void testSecurityLakeQueries() { v.ng(TestElement.INLINE_TABLE); v.ng(TestElement.FILE); v.ok(TestElement.INNER_JOIN); - v.ng(TestElement.CROSS_JOIN); + v.ok(TestElement.CROSS_JOIN); v.ok(TestElement.LEFT_OUTER_JOIN); - v.ng(TestElement.LEFT_SEMI_JOIN); - v.ng(TestElement.RIGHT_OUTER_JOIN); - v.ng(TestElement.FULL_OUTER_JOIN); - v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LEFT_SEMI_JOIN); + v.ok(TestElement.RIGHT_OUTER_JOIN); + v.ok(TestElement.FULL_OUTER_JOIN); + v.ok(TestElement.LEFT_ANTI_JOIN); v.ok(TestElement.LIKE_PREDICATE); v.ok(TestElement.LIMIT_CLAUSE); v.ok(TestElement.OFFSET_CLAUSE); From 4e8a02cc00daaa8451f587de422a41dfb7ef4628 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 20 Sep 2024 20:48:28 -0700 Subject: [PATCH 13/14] Fix style Signed-off-by: Tomoyuki Morita --- .../sql/spark/validator/S3GlueGrammarElementValidator.java | 5 ----- .../spark/validator/SecurityLakeGrammarElementValidator.java | 5 ----- 2 files changed, 10 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java index 799ba7975e..e7a0ce1b36 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java @@ -9,25 +9,20 @@ import static org.opensearch.sql.spark.validator.GrammarElement.CLUSTER_BY; import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_FUNCTION; import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_VIEW; -import static org.opensearch.sql.spark.validator.GrammarElement.CROSS_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_FUNCTION; import static org.opensearch.sql.spark.validator.GrammarElement.DISTRIBUTE_BY; import static org.opensearch.sql.spark.validator.GrammarElement.DROP_FUNCTION; import static org.opensearch.sql.spark.validator.GrammarElement.DROP_VIEW; import static org.opensearch.sql.spark.validator.GrammarElement.FILE; -import static org.opensearch.sql.spark.validator.GrammarElement.FULL_OUTER_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.HINTS; import static org.opensearch.sql.spark.validator.GrammarElement.INLINE_TABLE; import static org.opensearch.sql.spark.validator.GrammarElement.INSERT; -import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_ANTI_JOIN; -import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_SEMI_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.LOAD; import static org.opensearch.sql.spark.validator.GrammarElement.MANAGE_RESOURCE; import static org.opensearch.sql.spark.validator.GrammarElement.MISC_FUNCTIONS; import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_FUNCTION; import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_RESOURCE; import static org.opensearch.sql.spark.validator.GrammarElement.RESET; -import static org.opensearch.sql.spark.validator.GrammarElement.RIGHT_OUTER_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.SET; import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_FUNCTIONS; import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_VIEWS; diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java index 778074e353..ca8f2b5bdd 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java @@ -14,7 +14,6 @@ import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_FUNCTION; import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_NAMESPACE; import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_VIEW; -import static org.opensearch.sql.spark.validator.GrammarElement.CROSS_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.CSV_FUNCTIONS; import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_FUNCTION; import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_NAMESPACE; @@ -25,12 +24,9 @@ import static org.opensearch.sql.spark.validator.GrammarElement.DROP_NAMESPACE; import static org.opensearch.sql.spark.validator.GrammarElement.DROP_VIEW; import static org.opensearch.sql.spark.validator.GrammarElement.FILE; -import static org.opensearch.sql.spark.validator.GrammarElement.FULL_OUTER_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.HINTS; import static org.opensearch.sql.spark.validator.GrammarElement.INLINE_TABLE; import static org.opensearch.sql.spark.validator.GrammarElement.INSERT; -import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_ANTI_JOIN; -import static org.opensearch.sql.spark.validator.GrammarElement.LEFT_SEMI_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.LOAD; import static org.opensearch.sql.spark.validator.GrammarElement.MANAGE_RESOURCE; import static org.opensearch.sql.spark.validator.GrammarElement.MISC_FUNCTIONS; @@ -39,7 +35,6 @@ import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_TABLE; import static org.opensearch.sql.spark.validator.GrammarElement.REPAIR_TABLE; import static org.opensearch.sql.spark.validator.GrammarElement.RESET; -import static org.opensearch.sql.spark.validator.GrammarElement.RIGHT_OUTER_JOIN; import static org.opensearch.sql.spark.validator.GrammarElement.SET; import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_COLUMNS; import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_CREATE_TABLE; From b0e545e721cede0c0221212162dd8fa3813c93cb Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Mon, 23 Sep 2024 13:30:20 -0700 Subject: [PATCH 14/14] Fix coverage check Signed-off-by: Tomoyuki Morita --- async-query-core/build.gradle | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/async-query-core/build.gradle b/async-query-core/build.gradle index a1ff7f18b1..deba81735d 100644 --- a/async-query-core/build.gradle +++ b/async-query-core/build.gradle @@ -122,7 +122,8 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.flint.*', 'org.opensearch.sql.spark.flint.operation.*', 'org.opensearch.sql.spark.rest.*', - 'org.opensearch.sql.spark.utils.SQLQueryUtils.*' + 'org.opensearch.sql.spark.utils.SQLQueryUtils.*', + 'org.opensearch.sql.spark.validator.SQLQueryValidationVisitor' ] limit { counter = 'LINE' @@ -130,7 +131,7 @@ jacocoTestCoverageVerification { } limit { counter = 'BRANCH' - minimum = 0.9 + minimum = 1.0 } } }