Wrapper for Arrow Datasets & Dataset Pieces #754

Open · wants to merge 2 commits into master
4 changes: 2 additions & 2 deletions petastorm/arrow_reader_worker.py
@@ -19,10 +19,10 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from pyarrow import parquet as pq
from pyarrow.parquet import ParquetFile

from petastorm.cache import NullCache
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
from petastorm.workers_pool import EmptyResultError
from petastorm.workers_pool.worker_base import WorkerBase

@@ -127,7 +127,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
"""

if not self._dataset:
self._dataset = pq.ParquetDataset(
self._dataset = PetastormPyArrowDataset(
self._dataset_path_or_paths,
filesystem=self._filesystem,
validate_schema=False, filters=self._arrow_filters)
20 changes: 10 additions & 10 deletions petastorm/etl/dataset_metadata.py
@@ -27,6 +27,7 @@
from petastorm import utils
from petastorm.etl.legacy import depickle_legacy_package_name_compatible
from petastorm.fs_utils import FilesystemResolver, get_filesystem_and_path_or_paths, get_dataset_path
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset, PetastormPyArrowDatasetPiece
from petastorm.unischema import Unischema

logger = logging.getLogger(__name__)
@@ -104,7 +105,7 @@ def materialize_dataset(spark, dataset_url, schema, row_group_size_mb=None, use_
dataset_path = get_dataset_path(urlparse(dataset_url))
filesystem = filesystem_factory()

dataset = pq.ParquetDataset(
dataset = PetastormPyArrowDataset(
dataset_path,
filesystem=filesystem,
validate_schema=False)
@@ -114,7 +115,7 @@ def materialize_dataset(spark, dataset_url, schema, row_group_size_mb=None, use_
_generate_num_row_groups_per_file(dataset, spark.sparkContext, filesystem_factory)

# Reload the dataset to take into account the new metadata
dataset = pq.ParquetDataset(
dataset = PetastormPyArrowDataset(
dataset_path,
filesystem=filesystem,
validate_schema=False)
@@ -285,8 +286,8 @@ def load_row_groups(dataset):
# This is not a real "piece" and we won't have row_groups_per_file recorded for it.
if row_groups_key != ".":
for row_group in range(row_groups_per_file[row_groups_key]):
rowgroups.append(pq.ParquetDatasetPiece(piece.path, open_file_func=dataset.fs.open, row_group=row_group,
partition_keys=piece.partition_keys))
rowgroups.append(PetastormPyArrowDatasetPiece(piece.path, open_file_func=dataset.fs.open,
row_group=row_group, partition_keys=piece.partition_keys))
return rowgroups


@@ -322,18 +323,17 @@ def _split_row_groups(dataset):
continue

for row_group in range(row_groups_per_file[relative_path]):
split_piece = pq.ParquetDatasetPiece(piece.path, open_file_func=dataset.fs.open, row_group=row_group,
partition_keys=piece.partition_keys)
split_piece = PetastormPyArrowDatasetPiece(piece.path, open_file_func=dataset.fs.open, row_group=row_group,
partition_keys=piece.partition_keys)
split_pieces.append(split_piece)

return split_pieces


def _split_piece(piece, fs_open):
metadata = piece.get_metadata()
return [pq.ParquetDatasetPiece(piece.path, open_file_func=fs_open,
row_group=row_group,
partition_keys=piece.partition_keys)
return [PetastormPyArrowDatasetPiece(piece.path, open_file_func=fs_open, row_group=row_group,
partition_keys=piece.partition_keys)
for row_group in range(metadata.num_row_groups)]


@@ -399,7 +399,7 @@ def get_schema_from_dataset_url(dataset_url_or_urls, hdfs_driver='libhdfs3', sto
storage_options=storage_options,
filesystem=filesystem)

dataset = pq.ParquetDataset(path_or_paths, filesystem=fs, validate_schema=False, metadata_nthreads=10)
dataset = PetastormPyArrowDataset(path_or_paths, filesystem=fs, validate_schema=False, metadata_nthreads=10)

# Get a unischema stored in the dataset metadata.
stored_schema = get_schema(dataset)
6 changes: 3 additions & 3 deletions petastorm/etl/metadata_util.py
@@ -16,10 +16,10 @@
from __future__ import print_function

import argparse
from pyarrow import parquet as pq

from petastorm.etl import dataset_metadata, rowgroup_indexing
from petastorm.fs_utils import FilesystemResolver
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset

if __name__ == "__main__":

@@ -47,8 +47,8 @@

# Create pyarrow file system
resolver = FilesystemResolver(args.dataset_url, hdfs_driver=args.hdfs_driver)
dataset = pq.ParquetDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
validate_schema=False)
dataset = PetastormPyArrowDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
validate_schema=False)

print_all = not args.schema and not args.index
if args.schema or print_all:
4 changes: 2 additions & 2 deletions petastorm/etl/petastorm_generate_metadata.py
@@ -18,12 +18,12 @@
import sys
from pydoc import locate

from pyarrow import parquet as pq
from pyspark.sql import SparkSession

from petastorm.etl.dataset_metadata import materialize_dataset, get_schema, ROW_GROUPS_PER_FILE_KEY
from petastorm.etl.rowgroup_indexing import ROWGROUPS_INDEX_KEY
from petastorm.fs_utils import FilesystemResolver
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
from petastorm.unischema import Unischema
from petastorm.utils import add_to_dataset_metadata

@@ -63,7 +63,7 @@ def generate_petastorm_metadata(spark, dataset_url, unischema_class=None, use_su
resolver = FilesystemResolver(dataset_url, sc._jsc.hadoopConfiguration(), hdfs_driver=hdfs_driver,
user=spark.sparkContext.sparkUser())
fs = resolver.filesystem()
dataset = pq.ParquetDataset(
dataset = PetastormPyArrowDataset(
resolver.get_dataset_path(),
filesystem=fs,
validate_schema=False)
10 changes: 5 additions & 5 deletions petastorm/etl/rowgroup_indexing.py
@@ -16,14 +16,14 @@
import time
from collections import namedtuple

from pyarrow import parquet as pq
from six.moves import cPickle as pickle
from six.moves import range

from petastorm import utils
from petastorm.etl import dataset_metadata
from petastorm.etl.legacy import depickle_legacy_package_name_compatible
from petastorm.fs_utils import FilesystemResolver
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset, PetastormPyArrowDatasetPiece

logger = logging.getLogger(__name__)

@@ -51,8 +51,8 @@ def build_rowgroup_index(dataset_url, spark_context, indexers, hdfs_driver='libh
# Create pyarrow file system
resolver = FilesystemResolver(dataset_url, spark_context._jsc.hadoopConfiguration(),
hdfs_driver=hdfs_driver, user=spark_context.sparkUser())
dataset = pq.ParquetDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
validate_schema=False)
dataset = PetastormPyArrowDataset(resolver.get_dataset_path(), filesystem=resolver.filesystem(),
validate_schema=False)

split_pieces = dataset_metadata.load_row_groups(dataset)
schema = dataset_metadata.get_schema(dataset)
@@ -97,8 +97,8 @@ def _index_columns(piece_info, dataset_url, partitions, indexers, schema, hdfs_d
fs = resolver.filesystem()

# Create pyarrow piece
piece = pq.ParquetDatasetPiece(piece_info.path, open_file_func=fs.open, row_group=piece_info.row_group,
partition_keys=piece_info.partition_keys)
piece = PetastormPyArrowDatasetPiece(piece_info.path, open_file_func=fs.open, row_group=piece_info.row_group,
partition_keys=piece_info.partition_keys)

# Collect column names needed for indexing
column_names = set()
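Note: as an illustration of the piece wrapper introduced by this PR, here is a minimal sketch of what the `PetastormPyArrowDatasetPiece` construction in `_index_columns` above amounts to. The file name, row group index, and column name are placeholders, not values taken from the diff.

```python
# Illustrative only: wrap one row group of a local Parquet file and read a column subset.
# "part-00000.parquet" and the "id" column are placeholders.
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDatasetPiece

piece = PetastormPyArrowDatasetPiece("part-00000.parquet",
                                     open_file_func=lambda path: open(path, "rb"),
                                     row_group=0,
                                     partition_keys=[])

print(piece.get_metadata().num_rows)                  # file-level Parquet metadata
table = piece.read(columns=["id"], partitions=None)   # read() is keyword-only in the wrapper
```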
4 changes: 2 additions & 2 deletions petastorm/py_dict_reader_worker.py
@@ -18,11 +18,11 @@
from collections.abc import Iterable

import numpy as np
from pyarrow import parquet as pq
from pyarrow.parquet import ParquetFile

from petastorm import utils
from petastorm.cache import NullCache
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
from petastorm.workers_pool import EmptyResultError
from petastorm.workers_pool.worker_base import WorkerBase

@@ -133,7 +133,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
"""

if not self._dataset:
self._dataset = pq.ParquetDataset(
self._dataset = PetastormPyArrowDataset(
self._dataset_path,
filesystem=self._filesystem,
validate_schema=False, filters=self._arrow_filters)
98 changes: 98 additions & 0 deletions petastorm/pyarrow_helpers/dataset_wrapper.py
@@ -0,0 +1,98 @@
# Copyright (c) 2022 BlackBerry Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyarrow import parquet as pq


class PetastormPyArrowDataset:
Review comment (Collaborator): Please add a comment explaining why we need the wrapper.

Review reply: Good feature.


def __init__(self, path_or_paths, filesystem=None, validate_schema=None, filters=None, metadata_nthreads=None,
use_new_arrow_api=False):
if use_new_arrow_api:
raise NotImplementedError("The implementation using the new PyArrow API is not yet complete")

kwargs = {}
if filesystem:
kwargs["filesystem"] = filesystem
if validate_schema is not None:
kwargs["validate_schema"] = validate_schema
if filters:
kwargs["filters"] = filters
if metadata_nthreads:
kwargs["metadata_nthreads"] = metadata_nthreads

self._legacy_dataset = pq.ParquetDataset(path_or_paths, **kwargs)

@property
def common_metadata(self):
return self._legacy_dataset.common_metadata

@property
def metadata(self):
return self._legacy_dataset.metadata

@property
def metadata_path(self):
return self._legacy_dataset.metadata_path

@property
def fs(self):
return self._legacy_dataset.fs

@property
def partitions(self):
return self._legacy_dataset.partitions

@property
def paths(self):
return self._legacy_dataset.paths

@property
def pieces(self):
if not hasattr(self, "_wrapped_pieces"):
self._wrapped_pieces = [_wrap_piece(piece) for piece in self._legacy_dataset.pieces]
return self._wrapped_pieces


def _wrap_piece(piece):
return PetastormPyArrowDatasetPiece(piece.path, open_file_func=piece.open_file_func, row_group=piece.row_group,
partition_keys=piece.partition_keys)


class PetastormPyArrowDatasetPiece:

def __init__(self, path, open_file_func, row_group, partition_keys, use_new_arrow_api=False):
if use_new_arrow_api:
raise NotImplementedError("The implementation using the new PyArrow API is not yet complete")

self._legacy_piece = pq.ParquetDatasetPiece(path, open_file_func=open_file_func, row_group=row_group,
partition_keys=partition_keys)

def get_metadata(self):
return self._legacy_piece.get_metadata()

def read(self, *, columns, partitions):
return self._legacy_piece.read(columns=columns, partitions=partitions)

@property
def path(self):
return self._legacy_piece.path

@property
def partition_keys(self):
return self._legacy_piece.partition_keys

@property
def row_group(self):
return self._legacy_piece.row_group
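
Note: a minimal usage sketch of the dataset wrapper above, assuming a petastorm/Parquet dataset already exists at a hypothetical local path; it only exercises the properties defined in this file and is not part of the diff.

```python
# Hypothetical usage of PetastormPyArrowDataset; "/tmp/example_dataset" is a placeholder path.
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset

dataset = PetastormPyArrowDataset("/tmp/example_dataset",
                                  validate_schema=False, metadata_nthreads=10)

print(dataset.paths)            # path(s) backing the dataset
print(dataset.common_metadata)  # contents of _common_metadata, if present

# Pieces come back as PetastormPyArrowDatasetPiece objects and are cached after first access.
for piece in dataset.pieces:
    print(piece.path, piece.row_group, piece.partition_keys)

# The new-API path is intentionally a stub for now; requesting it raises NotImplementedError:
# PetastormPyArrowDataset("/tmp/example_dataset", use_new_arrow_api=True)
```

Routing every call site through this wrapper means a later switch to the new PyArrow Dataset API should only have to touch this one module.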
8 changes: 4 additions & 4 deletions petastorm/reader.py
@@ -17,7 +17,6 @@
import warnings

import six
from pyarrow import parquet as pq

from petastorm.arrow_reader_worker import ArrowReaderWorker
from petastorm.cache import NullCache
@@ -30,6 +29,7 @@
from petastorm.ngram import NGram
from petastorm.predicates import PredicateBase
from petastorm.py_dict_reader_worker import PyDictReaderWorker
from petastorm.pyarrow_helpers.dataset_wrapper import PetastormPyArrowDataset
from petastorm.reader_impl.arrow_table_serializer import ArrowTableSerializer
from petastorm.reader_impl.pickle_serializer import PickleSerializer
from petastorm.selectors import RowGroupSelectorBase
@@ -402,9 +402,9 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,

self.is_batched_reader = is_batched_reader
# 1. Resolve dataset path (hdfs://, file://) and open the parquet storage (dataset)
self.dataset = pq.ParquetDataset(dataset_path, filesystem=pyarrow_filesystem,
validate_schema=False, metadata_nthreads=10,
filters=filters)
self.dataset = PetastormPyArrowDataset(dataset_path, filesystem=pyarrow_filesystem,
validate_schema=False, metadata_nthreads=10,
filters=filters)

stored_schema = infer_or_load_unischema(self.dataset)

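Note: for reviewers wondering where the `use_new_arrow_api` flag is headed, the sketch below is one possible shape for the not-yet-implemented path on top of `pyarrow.dataset`. It is an assumption about future work, not code from this PR, and the helper name `_open_with_new_api` is invented.

```python
# Assumed direction only, not part of this PR: open a dataset with the new PyArrow Dataset API
# and derive row-group-level "pieces" from parquet fragments.
import pyarrow.dataset as ds


def _open_with_new_api(path_or_paths, filesystem=None):
    dataset = ds.dataset(path_or_paths, format="parquet", filesystem=filesystem)
    # Each ParquetFileFragment can be split so that one fragment covers a single row group,
    # which is roughly what the legacy ParquetDatasetPiece(row_group=...) provided.
    pieces = [row_group_fragment
              for fragment in dataset.get_fragments()
              for row_group_fragment in fragment.split_by_row_group()]
    return dataset, pieces
```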