Add a collate_lists_fn #772

Draft · wants to merge 1 commit into master
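
This pull request makes the collation of list-typed Parquet columns pluggable: `make_batch_reader` gains a `collate_lists_fn` argument, and the existing behavior of stacking equal-length lists into a matrix moves into a `default_list_collate_fn` that remains the default. A usage sketch follows the `petastorm/reader.py` changes below.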
38 changes: 22 additions & 16 deletions petastorm/arrow_reader_worker.py
@@ -23,13 +23,30 @@
 from pyarrow.parquet import ParquetFile

 from petastorm.cache import NullCache
+from petastorm.unischema import Unischema
 from petastorm.workers_pool import EmptyResultError
 from petastorm.workers_pool.worker_base import WorkerBase
+from typing import Callable, List
+
+
+def default_list_collate_fn(column_name: str, schema: Unischema, list_of_lists: List[List[np.ndarray]]):
+    try:
+        col_data = np.vstack(list_of_lists.tolist())
+        shape = schema.fields[column_name].shape
+        if len(shape) > 1:
+            col_data = col_data.reshape((len(list_of_lists),) + shape)
+        return col_data
+
+    except ValueError:
+        raise RuntimeError('All values in column \'{}\' are expected to have the same length. '
+                           'Got the following lengths: \'{}\''
+                           .format(column_name,
+                                   ', '.join(str(value.shape[0]) for value in list_of_lists)))


 class ArrowReaderWorkerResultsQueueReader(object):
-    def __init__(self):
-        pass
+    def __init__(self, collate_lists_fn: Callable):
+        self._collate_fn = collate_lists_fn or default_list_collate_fn

     @property
     def batched_output(self):
@@ -66,18 +83,7 @@ def read_next(self, workers_pool, schema, ngram):
             elif pa.types.is_list(column.type):
                 # Assuming all lists are of the same length, hence we can collate them into a matrix
                 list_of_lists = column_as_numpy
-                try:
-                    col_data = np.vstack(list_of_lists.tolist())
-                    shape = schema.fields[column_name].shape
-                    if len(shape) > 1:
-                        col_data = col_data.reshape((len(list_of_lists),) + shape)
-                    result_dict[column_name] = col_data
-
-                except ValueError:
-                    raise RuntimeError('Length of all values in column \'{}\' are expected to be the same length. '
-                                       'Got the following set of lengths: \'{}\''
-                                       .format(column_name,
-                                               ', '.join(str(value.shape[0]) for value in list_of_lists)))
+                result_dict[column_name] = self._collate_fn(column_name, schema, list_of_lists)
             else:
                 result_dict[column_name] = column_as_numpy
@@ -111,8 +117,8 @@ def __init__(self, worker_id, publish_func, args):
         self._dataset = None

     @staticmethod
-    def new_results_queue_reader():
-        return ArrowReaderWorkerResultsQueueReader()
+    def new_results_queue_reader(collate_fn: Callable):
+        return ArrowReaderWorkerResultsQueueReader(collate_fn)

     # pylint: disable=arguments-differ
     def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
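
As a reference point, here is a minimal standalone sketch (not part of the diff) of what `default_list_collate_fn` does for a column of equal-length lists, assuming a one-dimensional field shape so the reshape branch is not triggered:

```python
import numpy as np

# A list column arrives as a NumPy object array holding one array per row.
list_of_lists = np.empty(2, dtype=object)
list_of_lists[0] = np.array([1, 2, 3])
list_of_lists[1] = np.array([4, 5, 6])

# Equal-length rows stack cleanly into a (2, 3) matrix;
# rows of differing lengths would make np.vstack raise ValueError,
# which the worker surfaces as a RuntimeError.
col_data = np.vstack(list_of_lists.tolist())
assert col_data.shape == (2, 3)
```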
5 changes: 4 additions & 1 deletion petastorm/py_dict_reader_worker.py
@@ -16,6 +16,7 @@
 import hashlib
 import threading
 from collections.abc import Iterable
+from typing import Callable

 import numpy as np
 from pyarrow import parquet as pq
@@ -117,7 +118,9 @@ def __init__(self, worker_id, publish_func, args):
         self._dataset = None

     @staticmethod
-    def new_results_queue_reader():
+    def new_results_queue_reader(collate_lists_fn: Callable):
+        if collate_lists_fn is not None:
+            raise NotImplementedError('PyDictReaderWorker cannot collate records')
         return PyDictReaderWorkerResultsQueueReader()

     # pylint: disable=arguments-differ
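
Note that only the Arrow path gains collation support: `make_reader` (below) passes `collate_lists_fn=None` to `PyDictReaderWorker` unconditionally, so a custom collate function can only be supplied through `make_batch_reader`.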
12 changes: 8 additions & 4 deletions petastorm/reader.py
@@ -194,6 +194,7 @@ def make_reader(dataset_url,
         return Reader(filesystem, dataset_path,
                       worker_class=PyDictReaderWorker,
                       is_batched_reader=False,
+                      collate_lists_fn=None,
                       **kwargs)
     except PetastormMetadataError as e:
         logger.error('Unexpected exception: %s', str(e))
Expand All @@ -219,7 +220,8 @@ def make_batch_reader(dataset_url_or_urls,
                       filters=None,
                       storage_options=None,
                       zmq_copy_buffers=True,
-                      filesystem=None):
+                      filesystem=None,
+                      collate_lists_fn=None):
     """
     Creates an instance of Reader for reading batches out of a non-Petastorm Parquet store.

@@ -339,7 +341,8 @@ def make_batch_reader(dataset_url_or_urls,
                   cache=cache,
                   transform_spec=transform_spec,
                   is_batched_reader=True,
-                  filters=filters)
+                  filters=filters,
+                  collate_lists_fn=collate_lists_fn)


 class Reader(object):
@@ -353,7 +356,8 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
                  shuffle_row_drop_partitions=1,
                  predicate=None, rowgroup_selector=None, reader_pool=None, num_epochs=1,
                  cur_shard=None, shard_count=None, cache=None, worker_class=None,
-                 transform_spec=None, is_batched_reader=False, filters=None, shard_seed=None):
+                 transform_spec=None, is_batched_reader=False, filters=None, shard_seed=None,
+                 collate_lists_fn=None):
         """Initializes a reader object.

         :param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified,
@@ -429,7 +433,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,

         # By default, use original method of working with list of dictionaries and not arrow tables
         worker_class = worker_class or PyDictReaderWorker
-        self._results_queue_reader = worker_class.new_results_queue_reader()
+        self._results_queue_reader = worker_class.new_results_queue_reader(collate_lists_fn)

         if self.ngram and not self.ngram.timestamp_overlap and shuffle_row_drop_partitions > 1:
             raise NotImplementedError('Using timestamp_overlap=False is not implemented with'
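
Putting the pieces together, a minimal usage sketch; the dataset URL is a placeholder, and the padding collate mirrors the one in the new test below:

```python
import numpy as np

from petastorm import make_batch_reader


def pad_collate(column_name, schema, values):
    # Zero-pad each variable-length row to the longest row, then stack into a matrix.
    max_len = max(map(len, values))
    return np.asarray([np.pad(v, (0, max_len - len(v)), 'constant', constant_values=0)
                       for v in values])


# 'file:///tmp/my_dataset' is a hypothetical non-Petastorm Parquet store.
with make_batch_reader('file:///tmp/my_dataset', collate_lists_fn=pad_collate) as reader:
    for batch in reader:
        print(batch)
```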
22 changes: 21 additions & 1 deletion petastorm/tests/test_parquet_reader.py
@@ -24,7 +24,7 @@
 # pylint: disable=unnecessary-lambda
 from petastorm.tests.test_common import create_test_scalar_dataset
 from petastorm.transform import TransformSpec
-from petastorm.unischema import UnischemaField
+from petastorm.unischema import UnischemaField, Unischema

 _D = [lambda url, **kwargs: make_batch_reader(url, reader_pool_type='dummy', **kwargs)]

@@ -258,3 +258,23 @@ def test_random_seed(scalar_dataset):
         results.append(actual_row_ids)
     # Shuffled results are expected to be same
     np.testing.assert_equal(results[0], results[1])
+
+
+@pytest.mark.parametrize('reader_factory', _D + _TP)
+def test_read_with_collate(reader_factory, tmp_path):
+    data = pd.DataFrame({"str": ["a", "bc"], "varlen_nums": [[1], [3, 4]]})
+    path = tmp_path / 'data'
+    url = f"file://{path}"
+    data.to_parquet(path)
+
+    def collate_lists_fn(column_name: str, schema: Unischema, values):
+        max_len = max(map(len, values))
+        result = np.asarray([np.pad(v, (0, max_len - len(v)), 'constant', constant_values=0) for v in values])
+        return result
+
+    with reader_factory(url, collate_lists_fn=collate_lists_fn) as reader:
+        actual = list(reader)
+
+    assert len(actual) == 1
+    np.testing.assert_equal(actual[0].varlen_nums, [[1, 0], [3, 4]])
+    np.testing.assert_equal(actual[0].str, ["a", "bc"])