From cc99b936eaadc99ca8906c27c869abde35878a75 Mon Sep 17 00:00:00 2001
From: Dan Lidral-Porter
Date: Mon, 18 Apr 2022 14:09:37 -0700
Subject: [PATCH] Add wrappers around arrow legacy dataset classes

Add a wrapper around Arrow's ParquetDataset legacy class, to allow us to
re-implement that class's API using Arrow's new dataset class.

Add a wrapper around Arrow's ParquetDatasetPiece legacy class, to allow
us to re-implement that class's API using Arrow's new dataset "Fragment"
class.
---
 petastorm/arrow_reader_worker.py             |  4 +-
 petastorm/etl/dataset_metadata.py            | 20 ++--
 petastorm/etl/metadata_util.py               |  6 +-
 petastorm/etl/petastorm_generate_metadata.py |  4 +-
 petastorm/etl/rowgroup_indexing.py           | 10 +-
 petastorm/py_dict_reader_worker.py           |  4 +-
 petastorm/pyarrow_helpers/dataset_wrapper.py | 98 ++++++++++++++++++++
 petastorm/reader.py                          |  8 +-
 8 files changed, 126 insertions(+), 28 deletions(-)
 create mode 100644 petastorm/pyarrow_helpers/dataset_wrapper.py

diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py
index 9e6ffe510..3789241d8 100644
--- a/petastorm/arrow_reader_worker.py
+++ b/petastorm/arrow_reader_worker.py
@@ -19,10 +19,10 @@
 import numpy as np
 import pandas as pd
 import pyarrow as pa
-from pyarrow import parquet as pq
 from pyarrow.parquet import ParquetFile
 
 from petastorm.cache import NullCache
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
 from petastorm.workers_pool import EmptyResultError
 from petastorm.workers_pool.worker_base import WorkerBase
 
@@ -127,7 +127,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
         """
 
         if not self._dataset:
-            self._dataset = pq.ParquetDataset(
+            self._dataset = PetastormPyArrowDataset(
                 self._dataset_path_or_paths,
                 filesystem=self._filesystem,
                 validate_schema=False, filters=self._arrow_filters)
diff --git a/petastorm/etl/dataset_metadata.py b/petastorm/etl/dataset_metadata.py
index 77c192ca3..c46df35be 100644
--- a/petastorm/etl/dataset_metadata.py
+++ b/petastorm/etl/dataset_metadata.py
@@ -27,6 +27,7 @@
 from petastorm import utils
 from petastorm.etl.legacy import depickle_legacy_package_name_compatible
 from petastorm.fs_utils import FilesystemResolver, get_filesystem_and_path_or_paths, get_dataset_path
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset, PetastormPyArrowDatasetPiece
 from petastorm.unischema import Unischema
 
 logger = logging.getLogger(__name__)
@@ -104,7 +105,7 @@ def materialize_dataset(spark, dataset_url, schema, row_group_size_mb=None, use_
     dataset_path = get_dataset_path(urlparse(dataset_url))
     filesystem = filesystem_factory()
-    dataset = pq.ParquetDataset(
+    dataset = PetastormPyArrowDataset(
         dataset_path,
         filesystem=filesystem,
         validate_schema=False)
@@ -114,7 +115,7 @@
     _generate_num_row_groups_per_file(dataset, spark.sparkContext, filesystem_factory)
 
     # Reload the dataset to take into account the new metadata
-    dataset = pq.ParquetDataset(
+    dataset = PetastormPyArrowDataset(
         dataset_path,
         filesystem=filesystem,
         validate_schema=False)
@@ -285,8 +286,8 @@ def load_row_groups(dataset):
         # This is not a real "piece" and we won't have row_groups_per_file recorded for it.
         if row_groups_key != ".":
             for row_group in range(row_groups_per_file[row_groups_key]):
-                rowgroups.append(pq.ParquetDatasetPiece(piece.path, open_file_func=dataset.fs.open, row_group=row_group,
-                                                        partition_keys=piece.partition_keys))
+                rowgroups.append(PetastormPyArrowDatasetPiece(piece.path, open_file_func=dataset.fs.open,
+                                                              row_group=row_group, partition_keys=piece.partition_keys))
 
     return rowgroups
 
@@ -322,8 +323,8 @@ def _split_row_groups(dataset):
             continue
 
         for row_group in range(row_groups_per_file[relative_path]):
-            split_piece = pq.ParquetDatasetPiece(piece.path, open_file_func=dataset.fs.open, row_group=row_group,
-                                                 partition_keys=piece.partition_keys)
+            split_piece = PetastormPyArrowDatasetPiece(piece.path, open_file_func=dataset.fs.open, row_group=row_group,
+                                                       partition_keys=piece.partition_keys)
             split_pieces.append(split_piece)
 
     return split_pieces
@@ -331,9 +332,8 @@ def _split_row_groups(dataset):
 
 
 def _split_piece(piece, fs_open):
     metadata = piece.get_metadata()
-    return [pq.ParquetDatasetPiece(piece.path, open_file_func=fs_open,
-                                   row_group=row_group,
-                                   partition_keys=piece.partition_keys)
+    return [PetastormPyArrowDatasetPiece(piece.path, open_file_func=fs_open, row_group=row_group,
+                                         partition_keys=piece.partition_keys)
             for row_group in range(metadata.num_row_groups)]
 
@@ -399,7 +399,7 @@ def get_schema_from_dataset_url(dataset_url_or_urls, hdfs_driver='libhdfs3', sto
                                                          storage_options=storage_options,
                                                          filesystem=filesystem)
 
-    dataset = pq.ParquetDataset(path_or_paths, filesystem=fs, validate_schema=False, metadata_nthreads=10)
+    dataset = PetastormPyArrowDataset(path_or_paths, filesystem=fs, validate_schema=False, metadata_nthreads=10)
 
     # Get a unischema stored in the dataset metadata.
     stored_schema = get_schema(dataset)
diff --git a/petastorm/etl/metadata_util.py b/petastorm/etl/metadata_util.py
index 8530481a8..c9c502db7 100644
--- a/petastorm/etl/metadata_util.py
+++ b/petastorm/etl/metadata_util.py
@@ -16,10 +16,10 @@
 from __future__ import print_function
 
 import argparse
 
-from pyarrow import parquet as pq
 
 from petastorm.etl import dataset_metadata, rowgroup_indexing
 from petastorm.fs_utils import FilesystemResolver
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
 
 if __name__ == "__main__":
@@ -47,8 +47,8 @@
 
     # Create pyarrow file system
     resolver = FilesystemResolver(args.dataset_url, hdfs_driver=args.hdfs_driver)
-    dataset = pq.ParquetDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
-                                validate_schema=False)
+    dataset = PetastormPyArrowDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
+                                      validate_schema=False)
 
     print_all = not args.schema and not args.index
     if args.schema or print_all:
diff --git a/petastorm/etl/petastorm_generate_metadata.py b/petastorm/etl/petastorm_generate_metadata.py
index b33768dda..ca2a4e18c 100644
--- a/petastorm/etl/petastorm_generate_metadata.py
+++ b/petastorm/etl/petastorm_generate_metadata.py
@@ -18,12 +18,12 @@
 import sys
 from pydoc import locate
 
-from pyarrow import parquet as pq
 from pyspark.sql import SparkSession
 
 from petastorm.etl.dataset_metadata import materialize_dataset, get_schema, ROW_GROUPS_PER_FILE_KEY
 from petastorm.etl.rowgroup_indexing import ROWGROUPS_INDEX_KEY
 from petastorm.fs_utils import FilesystemResolver
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
 from petastorm.unischema import Unischema
 from petastorm.utils import add_to_dataset_metadata
@@ -63,7 +63,7 @@ def generate_petastorm_metadata(spark, dataset_url, unischema_class=None, use_su
     resolver = FilesystemResolver(dataset_url, sc._jsc.hadoopConfiguration(), hdfs_driver=hdfs_driver,
                                   user=spark.sparkContext.sparkUser())
     fs = resolver.filesystem()
-    dataset = pq.ParquetDataset(
+    dataset = PetastormPyArrowDataset(
         resolver.get_dataset_path(),
         filesystem=fs,
         validate_schema=False)
diff --git a/petastorm/etl/rowgroup_indexing.py b/petastorm/etl/rowgroup_indexing.py
index f0f8b067f..5f3a5e14a 100644
--- a/petastorm/etl/rowgroup_indexing.py
+++ b/petastorm/etl/rowgroup_indexing.py
@@ -16,7 +16,6 @@
 import time
 from collections import namedtuple
 
-from pyarrow import parquet as pq
 from six.moves import cPickle as pickle
 from six.moves import range
 
@@ -24,6 +23,7 @@
 from petastorm.etl import dataset_metadata
 from petastorm.etl.legacy import depickle_legacy_package_name_compatible
 from petastorm.fs_utils import FilesystemResolver
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset, PetastormPyArrowDatasetPiece
 
 logger = logging.getLogger(__name__)
 
@@ -51,8 +51,8 @@ def build_rowgroup_index(dataset_url, spark_context, indexers, hdfs_driver='libh
     # Create pyarrow file system
     resolver = FilesystemResolver(dataset_url, spark_context._jsc.hadoopConfiguration(), hdfs_driver=hdfs_driver,
                                   user=spark_context.sparkUser())
-    dataset = pq.ParquetDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
-                                validate_schema=False)
+    dataset = PetastormPyArrowDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
+                                      validate_schema=False)
 
     split_pieces = dataset_metadata.load_row_groups(dataset)
     schema = dataset_metadata.get_schema(dataset)
@@ -97,8 +97,8 @@ def _index_columns(piece_info, dataset_url, partitions, indexers, schema, hdfs_d
     fs = resolver.filesystem()
 
     # Create pyarrow piece
-    piece = pq.ParquetDatasetPiece(piece_info.path, open_file_func=fs.open, row_group=piece_info.row_group,
-                                   partition_keys=piece_info.partition_keys)
+    piece = PetastormPyArrowDatasetPiece(piece_info.path, open_file_func=fs.open, row_group=piece_info.row_group,
+                                         partition_keys=piece_info.partition_keys)
 
     # Collect column names needed for indexing
     column_names = set()
diff --git a/petastorm/py_dict_reader_worker.py b/petastorm/py_dict_reader_worker.py
index a9b01f288..e44a2ae1e 100644
--- a/petastorm/py_dict_reader_worker.py
+++ b/petastorm/py_dict_reader_worker.py
@@ -18,11 +18,11 @@
 from collections.abc import Iterable
 
 import numpy as np
-from pyarrow import parquet as pq
 from pyarrow.parquet import ParquetFile
 
 from petastorm import utils
 from petastorm.cache import NullCache
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
 from petastorm.workers_pool import EmptyResultError
 from petastorm.workers_pool.worker_base import WorkerBase
 
@@ -133,7 +133,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
         """
 
         if not self._dataset:
-            self._dataset = pq.ParquetDataset(
+            self._dataset = PetastormPyArrowDataset(
                 self._dataset_path,
                 filesystem=self._filesystem,
                 validate_schema=False, filters=self._arrow_filters)
diff --git a/petastorm/pyarrow_helpers/dataset_wrapper.py b/petastorm/pyarrow_helpers/dataset_wrapper.py
new file mode 100644
index 000000000..1818c4f6b
--- /dev/null
+++ b/petastorm/pyarrow_helpers/dataset_wrapper.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2022 BlackBerry Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pyarrow import parquet as pq
+
+
+class PetastormPyArrowDataset:
+
+    def __init__(self, path_or_paths, filesystem=None, validate_schema=None, filters=None, metadata_nthreads=None,
+                 use_new_arrow_api=False):
+        if use_new_arrow_api:
+            raise NotImplementedError("The implementation using the new PyArrow API is not yet complete")
+
+        kwargs = {}
+        if filesystem:
+            kwargs["filesystem"] = filesystem
+        if validate_schema is not None:
+            kwargs["validate_schema"] = validate_schema
+        if filters:
+            kwargs["filters"] = filters
+        if metadata_nthreads:
+            kwargs["metadata_nthreads"] = metadata_nthreads
+
+        self._legacy_dataset = pq.ParquetDataset(path_or_paths, **kwargs)
+
+    @property
+    def common_metadata(self):
+        return self._legacy_dataset.common_metadata
+
+    @property
+    def metadata(self):
+        return self._legacy_dataset.metadata
+
+    @property
+    def metadata_path(self):
+        return self._legacy_dataset.metadata_path
+
+    @property
+    def fs(self):
+        return self._legacy_dataset.fs
+
+    @property
+    def partitions(self):
+        return self._legacy_dataset.partitions
+
+    @property
+    def paths(self):
+        return self._legacy_dataset.paths
+
+    @property
+    def pieces(self):
+        if not hasattr(self, "_wrapped_pieces"):
+            self._wrapped_pieces = [_wrap_piece(piece) for piece in self._legacy_dataset.pieces]
+        return self._wrapped_pieces
+
+
+def _wrap_piece(piece):
+    return PetastormPyArrowDatasetPiece(piece.path, open_file_func=piece.open_file_func, row_group=piece.row_group,
+                                        partition_keys=piece.partition_keys)
+
+
+class PetastormPyArrowDatasetPiece:
+
+    def __init__(self, path, open_file_func, row_group, partition_keys, use_new_arrow_api=False):
+        if use_new_arrow_api:
+            raise NotImplementedError("The implementation using the new PyArrow API is not yet complete")
+
+        self._legacy_piece = pq.ParquetDatasetPiece(path, open_file_func=open_file_func, row_group=row_group,
+                                                    partition_keys=partition_keys)
+
+    def get_metadata(self):
+        return self._legacy_piece.get_metadata()
+
+    def read(self, *, columns, partitions):
+        return self._legacy_piece.read(columns=columns, partitions=partitions)
+
+    @property
+    def path(self):
+        return self._legacy_piece.path
+
+    @property
+    def partition_keys(self):
+        return self._legacy_piece.partition_keys
+
+    @property
+    def row_group(self):
+        return self._legacy_piece.row_group
diff --git a/petastorm/reader.py b/petastorm/reader.py
index 48ce9fd4c..1699f177c 100644
--- a/petastorm/reader.py
+++ b/petastorm/reader.py
@@ -17,7 +17,6 @@
 import warnings
 
 import six
-from pyarrow import parquet as pq
 
 from petastorm.arrow_reader_worker import ArrowReaderWorker
 from petastorm.cache import NullCache
@@ -30,6 +29,7 @@
 from petastorm.ngram import NGram
 from petastorm.predicates import PredicateBase
 from petastorm.py_dict_reader_worker import PyDictReaderWorker
+from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
 from petastorm.reader_impl.arrow_table_serializer import ArrowTableSerializer
 from petastorm.reader_impl.pickle_serializer import PickleSerializer
 from petastorm.selectors import RowGroupSelectorBase
@@ -402,9 +402,9 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
         self.is_batched_reader = is_batched_reader
 
         # 1. Resolve dataset path (hdfs://, file://) and open the parquet storage (dataset)
-        self.dataset = pq.ParquetDataset(dataset_path, filesystem=pyarrow_filesystem,
-                                         validate_schema=False, metadata_nthreads=10,
-                                         filters=filters)
+        self.dataset = PetastormPyArrowDataset(dataset_path, filesystem=pyarrow_filesystem,
+                                               validate_schema=False, metadata_nthreads=10,
+                                               filters=filters)
 
         stored_schema = infer_or_load_unischema(self.dataset)
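
For reviewers, a minimal usage sketch of the new wrapper, mirroring how the call sites in this patch construct the dataset and walk its pieces. It is illustrative only and not part of the patch; the dataset path is hypothetical and assumes a non-partitioned Petastorm dataset already exists at that location.

    # Open an existing dataset through the wrapper; the legacy pq.ParquetDataset does the work underneath.
    from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset

    dataset = PetastormPyArrowDataset('/tmp/hello_world_dataset', validate_schema=False)  # hypothetical path

    # Wrapped pieces expose the same path/row_group/partition_keys/read() surface the ETL code relies on.
    for piece in dataset.pieces:
        table = piece.read(columns=None, partitions=dataset.partitions)
        print(piece.path, piece.row_group, table.num_rows)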