def rewrite(self, source_bucket, source_object, destination_bucket,
            destination_object=None):
    """
    Copy an object by issuing ``objects().rewrite`` requests until done.

    Has the same functionality as copy, except that it will work on files
    over 5 TB, as well as when copying between locations and/or storage
    classes.

    destination_object can be omitted, in which case source_object is used.

    :param source_bucket: The bucket of the object to copy from.
    :type source_bucket: string
    :param source_object: The object to copy.
    :type source_object: string
    :param destination_bucket: The bucket the object is copied to.
    :type destination_bucket: string
    :param destination_object: The (renamed) path of the object if given.
        Can be omitted; then the same name is used.
    :type destination_object: string
    :return: True once a response reports ``done``; False if a request
        fails with an HTTP 404 status.
    :rtype: bool
    :raises ValueError: if source_bucket or source_object is empty, or if
        source and destination name the same bucket and object.
    """
    # Validate emptiness first: two empty buckets also compare equal, and
    # "must be different" would be a misleading message for that case.
    if not source_bucket or not source_object:
        raise ValueError('source_bucket and source_object cannot be empty.')

    destination_object = destination_object or source_object
    if (source_bucket == destination_bucket and
            source_object == destination_object):
        raise ValueError(
            'Either source/destination bucket or source/destination object '
            'must be different, not both the same: bucket=%s, object=%s' %
            (source_bucket, source_object))

    service = self.get_conn()
    request_args = {
        'sourceBucket': source_bucket,
        'sourceObject': source_object,
        'destinationBucket': destination_bucket,
        'destinationObject': destination_object,
        'body': '',
    }
    request_count = 0
    try:
        while True:
            request_count += 1
            result = service.objects().rewrite(**request_args).execute()
            self.log.info('Rewrite request #%s: %s', request_count, result)
            if result['done']:
                return True
            # Not finished yet: the follow-up request must carry the token
            # returned by the previous response.
            request_args['rewriteToken'] = result['rewriteToken']
    except errors.HttpError as ex:
        if ex.resp['status'] == '404':
            return False
        raise