From 1e6c0e96bd15f1fa79cac21e96789d0182d5e9d4 Mon Sep 17 00:00:00 2001
From: Giovanni Lanzani
Date: Tue, 8 May 2018 12:41:51 +0200
Subject: [PATCH] [AIRFLOW-2427] Add tests to named hive sensor

Closes #3323 from gglanzani/AIRFLOW-2427

(cherry picked from commit b18b437c216b0c4b3ffb41e4934f3c2dd966c14b)
Signed-off-by: Fokko Driesprong
---
 .../sensors/named_hive_partition_sensor.py |  68 +++++----
 .../test_named_hive_partition_sensor.py    | 130 ++++++++++++++++++
 2 files changed, 169 insertions(+), 29 deletions(-)
 create mode 100644 tests/sensors/test_named_hive_partition_sensor.py

diff --git a/airflow/sensors/named_hive_partition_sensor.py b/airflow/sensors/named_hive_partition_sensor.py
index a42a3608a4f892..4a076a3dd6870c 100644
--- a/airflow/sensors/named_hive_partition_sensor.py
+++ b/airflow/sensors/named_hive_partition_sensor.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -48,6 +48,7 @@ def __init__(self,
                  partition_names,
                  metastore_conn_id='metastore_default',
                  poke_interval=60 * 3,
+                 hook=None,
                  *args,
                  **kwargs):
         super(NamedHivePartitionSensor, self).__init__(
@@ -58,37 +59,46 @@ def __init__(self,
         self.metastore_conn_id = metastore_conn_id
         self.partition_names = partition_names
-        self.next_poke_idx = 0
-
-    @classmethod
-    def parse_partition_name(self, partition):
-        try:
-            schema, table_partition = partition.split('.', 1)
-            table, partition = table_partition.split('/', 1)
-            return schema, table, partition
-        except ValueError as e:
-            raise ValueError('Could not parse ' + partition)
-
-    def poke(self, context):
-        if not hasattr(self, 'hook'):
+        self.hook = hook
+        if self.hook and metastore_conn_id != 'metastore_default':
+            self.log.warning('A hook was passed but a non-default '
+                             'metastore_conn_id='
+                             '{} was used'.format(metastore_conn_id))
+
+    @staticmethod
+    def parse_partition_name(partition):
+        first_split = partition.split('.', 1)
+        if len(first_split) == 1:
+            schema = 'default'
+            table_partition = first_split[0]  # no schema part given
+        else:
+            schema, table_partition = first_split
+        second_split = table_partition.split('/', 1)
+        if len(second_split) == 1:
+            raise ValueError('Could not parse ' + partition +
+                             ' into table, partition')
+        else:
+            table, partition = second_split
+        return schema, table, partition
+
+    def poke_partition(self, partition):
+        if not self.hook:
             from airflow.hooks.hive_hooks import HiveMetastoreHook
             self.hook = HiveMetastoreHook(
                 metastore_conn_id=self.metastore_conn_id)
 
-        def poke_partition(partition):
-
-            schema, table, partition = self.parse_partition_name(partition)
+        schema, table, partition = self.parse_partition_name(partition)
 
-            self.log.info(
-                'Poking for {schema}.{table}/{partition}'.format(**locals())
-            )
-            return self.hook.check_for_named_partition(
-                schema, table, partition)
+        self.log.info(
+            'Poking for {schema}.{table}/{partition}'.format(**locals())
+        )
+        return self.hook.check_for_named_partition(
+            schema, table, partition)
 
-        while self.next_poke_idx < len(self.partition_names):
-            if poke_partition(self.partition_names[self.next_poke_idx]):
-                self.next_poke_idx += 1
-            else:
-                return False
+    def poke(self, context):
 
-        return True
+        self.partition_names = [
+            partition_name for partition_name in self.partition_names
+            if not self.poke_partition(partition_name)
+        ]
+        return not self.partition_names
diff --git a/tests/sensors/test_named_hive_partition_sensor.py b/tests/sensors/test_named_hive_partition_sensor.py
new file mode 100644
index 00000000000000..4fef3e0f349548
--- /dev/null
+++ b/tests/sensors/test_named_hive_partition_sensor.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import random
+import unittest
+from datetime import timedelta
+
+from airflow import configuration, DAG, operators
+from airflow.sensors.named_hive_partition_sensor import NamedHivePartitionSensor
+from airflow.utils.timezone import datetime
+from airflow.hooks.hive_hooks import HiveMetastoreHook
+
+configuration.load_test_config()
+
+DEFAULT_DATE = datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+
+
+class NamedHivePartitionSensorTests(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        self.dag = DAG('test_dag_id', default_args=args)
+        self.next_day = (DEFAULT_DATE +
+                         timedelta(days=1)).isoformat()[:10]
+        self.database = 'airflow'
+        self.partition_by = 'ds'
+        self.table = 'static_babynames_partitioned'
+        self.hql = """
+        CREATE DATABASE IF NOT EXISTS {{ params.database }};
+        USE {{ params.database }};
+        DROP TABLE IF EXISTS {{ params.table }};
+        CREATE TABLE IF NOT EXISTS {{ params.table }} (
+            state string,
+            year string,
+            name string,
+            gender string,
+            num int)
+        PARTITIONED BY ({{ params.partition_by }} string);
+        ALTER TABLE {{ params.table }}
+        ADD PARTITION({{ params.partition_by }}='{{ ds }}');
+        """
+        self.hook = HiveMetastoreHook()
+        t = operators.hive_operator.HiveOperator(
+            task_id='HiveHook_' + str(random.randint(1, 10000)),
+            params={
+                'database': self.database,
+                'table': self.table,
+                'partition_by': self.partition_by
+            },
+            hive_cli_conn_id='beeline_default',
+            hql=self.hql, dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+              ignore_ti_state=True)
+
+    def tearDown(self):
+        hook = HiveMetastoreHook()
+        with hook.get_conn() as metastore:
+            metastore.drop_table(self.database, self.table, deleteData=True)
+
+    def test_parse_partition_name_correct(self):
+        schema = 'default'
+        table = 'users'
+        partition = 'ds=2016-01-01/state=IT'
+        name = '{schema}.{table}/{partition}'.format(schema=schema,
+                                                     table=table,
+                                                     partition=partition)
+        parsed_schema, parsed_table, parsed_partition = (
+            NamedHivePartitionSensor.parse_partition_name(name)
+        )
+        self.assertEqual(schema, parsed_schema)
+        self.assertEqual(table, parsed_table)
+        self.assertEqual(partition, parsed_partition)
+
+    def test_parse_partition_name_incorrect(self):
+        name = 'incorrect.name'
+        with self.assertRaises(ValueError):
+            NamedHivePartitionSensor.parse_partition_name(name)
+
+    def test_parse_partition_name_default(self):
+        table = 'users'
+        partition = 'ds=2016-01-01/state=IT'
+        name = '{table}/{partition}'.format(table=table,
+                                            partition=partition)
+        parsed_schema, parsed_table, parsed_partition = (
+            NamedHivePartitionSensor.parse_partition_name(name)
+        )
+        self.assertEqual('default', parsed_schema)
+        self.assertEqual(table, parsed_table)
+        self.assertEqual(partition, parsed_partition)
+
+    def test_poke_existing(self):
+        partitions = ["{}.{}/{}={}".format(self.database,
+                                           self.table,
+                                           self.partition_by,
+                                           DEFAULT_DATE_DS)]
+        sensor = NamedHivePartitionSensor(partition_names=partitions,
+                                          task_id='test_poke_existing',
+                                          poke_interval=1,
+                                          hook=self.hook,
+                                          dag=self.dag)
+        self.assertTrue(sensor.poke(None))
+
+    def test_poke_non_existing(self):
+        partitions = ["{}.{}/{}={}".format(self.database,
+                                           self.table,
+                                           self.partition_by,
+                                           self.next_day)]
+        sensor = NamedHivePartitionSensor(partition_names=partitions,
+                                          task_id='test_poke_non_existing',
+                                          poke_interval=1,
+                                          hook=self.hook,
+                                          dag=self.dag)
+        self.assertFalse(sensor.poke(None))
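
A minimal usage sketch of the behaviour this patch introduces (reviewer
note, not part of the diff above): the FakeHook class and the partition
names below are illustrative stand-ins, not anything shipped with Airflow.
The snippet shows parse_partition_name falling back to the 'default'
schema and poke() whittling partition_names down to the partitions that
are still missing.

    from airflow.sensors.named_hive_partition_sensor import (
        NamedHivePartitionSensor)

    # Schema-qualified and bare names both parse; the bare form falls
    # back to the 'default' schema.
    assert NamedHivePartitionSensor.parse_partition_name(
        'airflow.users/ds=2016-01-01') == ('airflow', 'users',
                                           'ds=2016-01-01')
    assert NamedHivePartitionSensor.parse_partition_name(
        'users/ds=2016-01-01') == ('default', 'users', 'ds=2016-01-01')


    class FakeHook(object):
        """Illustrative stand-in for HiveMetastoreHook that pretends
        only the ds=2016-01-01 partition exists."""

        def check_for_named_partition(self, schema, table, partition):
            return partition == 'ds=2016-01-01'


    sensor = NamedHivePartitionSensor(
        task_id='example',
        partition_names=['users/ds=2016-01-01', 'users/ds=2016-01-02'],
        hook=FakeHook())

    # The first poke finds ds=2016-01-01 and drops it from
    # partition_names; ds=2016-01-02 is still missing, so poke()
    # returns False and later pokes only re-check that partition.
    assert sensor.poke(context=None) is False
    assert sensor.partition_names == ['users/ds=2016-01-02']

Because poke() now rewrites partition_names in place, a sensor instance
tracks the partitions still to be found across pokes; this replaces the
next_poke_idx counter of the previous implementation.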