diff --git a/.github/workflows/functional-tests-workflow.yml b/.github/workflows/functional-tests-workflow.yml index a7e82a45..56a21cb2 100644 --- a/.github/workflows/functional-tests-workflow.yml +++ b/.github/workflows/functional-tests-workflow.yml @@ -35,10 +35,10 @@ jobs: with: ref: ${{ inputs.checkout-ref }} repository: ${{ inputs.checkout-repository }} - - name: Set up Python 3.10 + - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.8' - name: Install dependencies run: | make install_deps diff --git a/dbt/adapters/athena/__version__.py b/dbt/adapters/athena/__version__.py index cd70cb07..a55413d1 100644 --- a/dbt/adapters/athena/__version__.py +++ b/dbt/adapters/athena/__version__.py @@ -1 +1 @@ -version = "1.6.4" +version = "1.7.0" diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index f794c635..2200c0ca 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -58,6 +58,7 @@ from dbt.adapters.base.impl import AdapterConfig from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter +from dbt.clients.agate_helper import table_from_rows from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import CompiledNode, ConstraintType @@ -527,6 +528,9 @@ def _get_one_catalog( schemas: Dict[str, Optional[Set[str]]], manifest: Manifest, ) -> agate.Table: + """ + This function is invoked by Adapter.get_catalog for each schema. + """ data_catalog = self._get_data_catalog(information_schema.path.database) data_catalog_type = get_catalog_type(data_catalog) @@ -581,6 +585,10 @@ def _get_one_catalog( return self._join_catalog_table_owners(filtered_table, manifest) def _get_catalog_schemas(self, manifest: Manifest) -> AthenaSchemaSearchMap: + """ + Get the schemas from the catalog. + It's called by the `get_catalog` method. + """ info_schema_name_map = AthenaSchemaSearchMap() nodes: Iterator[CompiledNode] = chain( [node for node in manifest.nodes.values() if (node.is_relational and not node.is_ephemeral_model)], @@ -658,6 +666,31 @@ def list_relations_without_caching(self, schema_relation: AthenaRelation) -> Lis return relations + def _get_one_catalog_by_relations( + self, + information_schema: InformationSchema, + relations: List[BaseRelation], + manifest: Manifest, + ) -> agate.Table: + """ + Overwrite of _get_one_catalog_by_relations for Athena, in order to use glue apis. + This function is invoked by Adapter.get_catalog_by_relations. + """ + _table_definitions = [] + for _rel in relations: + glue_table_definition = self.get_glue_table(_rel) + if glue_table_definition: + _table_definition = self._get_one_table_for_catalog(glue_table_definition["Table"], _rel.database) + _table_definitions.extend(_table_definition) + table = agate.Table.from_object(_table_definitions) + # picked from _catalog_filter_table, force database + schema to be strings + table_casted = table_from_rows( + table.rows, + table.column_names, + text_only_columns=["table_database", "table_schema", "table_name"], + ) + return self._join_catalog_table_owners(table_casted, manifest) + @available def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelation) -> None: conn = self.connections.get_thread_connection() diff --git a/dbt/include/athena/macros/adapters/metadata.sql b/dbt/include/athena/macros/adapters/metadata.sql index 0ca7cdb8..ae7a187c 100644 --- a/dbt/include/athena/macros/adapters/metadata.sql +++ b/dbt/include/athena/macros/adapters/metadata.sql @@ -11,3 +11,7 @@ {% macro athena__list_relations_without_caching(schema_relation) %} {{ return(adapter.list_relations_without_caching(schema_relation)) }} {% endmacro %} + +{% macro athena__get_catalog_relations(information_schema, relations) %} + {{ return(adapter.get_catalog_by_relations(information_schema, relations)) }} +{% endmacro %} diff --git a/dev-requirements.txt b/dev-requirements.txt index 0b20006b..c5687ea7 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,7 +1,7 @@ autoflake~=1.7 black~=23.10 boto3-stubs[s3]~=1.28 -dbt-tests-adapter~=1.6.6 +dbt-tests-adapter~=1.7.0 flake8~=6.1 Flake8-pyproject~=1.2 isort~=5.11 diff --git a/setup.py b/setup.py index af6941ee..a2509409 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def _get_package_version() -> str: return f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}' -dbt_version = "1.6" +dbt_version = "1.7" package_version = _get_package_version() description = "The athena adapter plugin for dbt (data build tool)" @@ -55,7 +55,7 @@ def _get_package_version() -> str: # In order to control dbt-core version and package version "boto3~=1.26", "boto3-stubs[athena,glue,lakeformation,sts]~=1.26", - "dbt-core~=1.6.0", + "dbt-core~=1.7.0", "pyathena>=2.25,<4.0", "pydantic>=1.10,<3.0", "tenacity~=8.2", diff --git a/tests/conftest.py b/tests/conftest.py index e6be8531..b3d1d35f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,9 @@ import pytest from dbt.events.base_types import EventLevel -from dbt.events.eventmgr import LineFormat, NoFilter +from dbt.events.eventmgr import LineFormat from dbt.events.functions import EVENT_MANAGER, _get_stdout_config +from dbt.events.logger import NoFilter # Import the functional fixtures as a plugin # Note: fixtures with session scope need to be local diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index 9295f681..23f91e10 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -12,7 +12,9 @@ class TestAthenaConstraintQuotedColumn(BaseConstraintQuotedColumn): def models(self): return { "my_model.sql": my_model_with_quoted_column_name_sql, - "constraints_schema.yml": model_quoted_column_schema_yml.replace("text", "string"), + # we replace text type with varchar + # do not replace text with string, because then string is replaced with a capital TEXT that leads to failure + "constraints_schema.yml": model_quoted_column_schema_yml.replace("text", "varchar"), } @pytest.fixture(scope="class") diff --git a/tests/functional/adapter/test_docs.py b/tests/functional/adapter/test_docs.py new file mode 100644 index 00000000..663d3027 --- /dev/null +++ b/tests/functional/adapter/test_docs.py @@ -0,0 +1,101 @@ +import os + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.adapter.basic.expected_catalog import base_expected_catalog, no_stats +from dbt.tests.adapter.basic.test_docs_generate import ( + BaseDocsGenerate, + run_and_generate, + verify_metadata, +) +from dbt.tests.util import get_artifact, run_dbt + +model_sql = """ +select 1 as id +""" + +override_macros_sql = """ +{% macro get_catalog_relations(information_schema, relations) %} + {{ return(adapter.get_catalog_by_relations(information_schema, relations)) }} +{% endmacro %} +""" + + +def custom_verify_catalog_athena(project, expected_catalog, start_time): + # get the catalog.json + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + assert os.path.exists(catalog_path) + catalog = get_artifact(catalog_path) + + # verify the catalog + assert set(catalog) == {"errors", "nodes", "sources", "metadata"} + verify_metadata( + catalog["metadata"], + "https://schemas.getdbt.com/dbt/catalog/v1.json", + start_time, + ) + assert not catalog["errors"] + + for key in "nodes", "sources": + for unique_id, expected_node in expected_catalog[key].items(): + found_node = catalog[key][unique_id] + for node_key in expected_node: + assert node_key in found_node + # the value of found_node[node_key] is not exactly expected_node[node_key] + + +class TestDocsGenerate(BaseDocsGenerate): + """ + Override of BaseDocsGenerate to make it working with Athena + """ + + @pytest.fixture(scope="class") + def expected_catalog(self, project): + return base_expected_catalog( + project, + role="test", + id_type="integer", + text_type="text", + time_type="timestamp without time zone", + view_type="VIEW", + table_type="BASE TABLE", + model_stats=no_stats(), + ) + + def test_run_and_generate_no_compile(self, project, expected_catalog): + start_time = run_and_generate(project, ["--no-compile"]) + assert not os.path.exists(os.path.join(project.project_root, "target", "manifest.json")) + custom_verify_catalog_athena(project, expected_catalog, start_time) + + # Test generic "docs generate" command + def test_run_and_generate(self, project, expected_catalog): + start_time = run_and_generate(project) + custom_verify_catalog_athena(project, expected_catalog, start_time) + + # Check that assets have been copied to the target directory for use in the docs html page + assert os.path.exists(os.path.join(".", "target", "assets")) + assert os.path.exists(os.path.join(".", "target", "assets", "lorem-ipsum.txt")) + assert not os.path.exists(os.path.join(".", "target", "non-existent-assets")) + + +class TestDocsGenerateOverride: + @pytest.fixture(scope="class") + def models(self): + return {"model.sql": model_sql} + + @pytest.fixture(scope="class") + def macros(self): + return {"override_macros_sql.sql": override_macros_sql} + + def test_generate_docs( + self, + project, + ): + results = run_dbt(["run"]) + assert len(results) == 1 + + docs_generate = run_dbt(["--warn-error", "docs", "generate"]) + assert len(docs_generate._compile_results.results) == 1 + assert docs_generate._compile_results.results[0].status == RunStatus.Success + assert docs_generate.errors is None diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 733d7f8d..22a4d24f 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -652,12 +652,56 @@ def test__get_one_catalog(self, mock_aws_service): ("awsdatacatalog", "baz", "qux", "table", None, "id", 0, "string", None, "data-engineers"), ("awsdatacatalog", "baz", "qux", "table", None, "country", 1, "string", None, "data-engineers"), ] - assert actual.column_names == expected_column_names assert len(actual.rows) == len(expected_rows) for row in actual.rows.values(): assert row.values() in expected_rows + @mock_glue + @mock_athena + @mock_sts + def test__get_one_catalog_by_relations(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database("foo") + mock_aws_service.create_database("quux") + mock_aws_service.create_table(database_name="foo", table_name="bar") + # we create another relations + mock_aws_service.create_table(table_name="bar", database_name="quux") + + mock_information_schema = mock.MagicMock() + mock_information_schema.path.database = "awsdatacatalog" + + self.adapter.acquire_connection("dummy") + + rel_1 = self.adapter.Relation.create( + database="awsdatacatalog", + schema="foo", + identifier="bar", + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + "table_owner", + ) + + expected_rows = [ + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"), + ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"), + ] + + actual = self.adapter._get_one_catalog_by_relations(mock_information_schema, [rel_1], self.mock_manifest) + assert actual.column_names == expected_column_names + assert actual.rows == expected_rows + @mock_glue @mock_athena def test__get_one_catalog_shared_catalog(self, mock_aws_service):