diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py
index fedad01d..95608b30 100644
--- a/petastorm/arrow_reader_worker.py
+++ b/petastorm/arrow_reader_worker.py
@@ -23,13 +23,30 @@
 from pyarrow.parquet import ParquetFile
 
 from petastorm.cache import NullCache
+from petastorm.unischema import Unischema
 from petastorm.workers_pool import EmptyResultError
 from petastorm.workers_pool.worker_base import WorkerBase
+from typing import Callable, Optional
+
+
+def default_list_collate_fn(column_name: str, schema: Unischema, list_of_lists: np.ndarray):
+    try:
+        col_data = np.vstack(list_of_lists.tolist())
+        shape = schema.fields[column_name].shape
+        if len(shape) > 1:
+            col_data = col_data.reshape((len(list_of_lists),) + shape)
+        return col_data
+
+    except ValueError:
+        raise RuntimeError('Length of all values in column \'{}\' are expected to be the same length. '
+                           'Got the following set of lengths: \'{}\''
+                           .format(column_name,
+                                   ', '.join(str(value.shape[0]) for value in list_of_lists)))
 
 
 class ArrowReaderWorkerResultsQueueReader(object):
-    def __init__(self):
-        pass
+    def __init__(self, collate_lists_fn: Optional[Callable]):
+        self._collate_fn = collate_lists_fn or default_list_collate_fn
 
     @property
     def batched_output(self):
@@ -66,18 +83,7 @@ def read_next(self, workers_pool, schema, ngram):
                 elif pa.types.is_list(column.type):
                     # Assuming all lists are of the same length, hence we can collate them into a matrix
                     list_of_lists = column_as_numpy
-                    try:
-                        col_data = np.vstack(list_of_lists.tolist())
-                        shape = schema.fields[column_name].shape
-                        if len(shape) > 1:
-                            col_data = col_data.reshape((len(list_of_lists),) + shape)
-                        result_dict[column_name] = col_data
-
-                    except ValueError:
-                        raise RuntimeError('Length of all values in column \'{}\' are expected to be the same length. '
-                                           'Got the following set of lengths: \'{}\''
-                                           .format(column_name,
-                                                   ', '.join(str(value.shape[0]) for value in list_of_lists)))
+                    result_dict[column_name] = self._collate_fn(column_name, schema, list_of_lists)
                 else:
                     result_dict[column_name] = column_as_numpy
 
@@ -111,8 +117,8 @@ def __init__(self, worker_id, publish_func, args):
         self._dataset = None
 
     @staticmethod
-    def new_results_queue_reader():
-        return ArrowReaderWorkerResultsQueueReader()
+    def new_results_queue_reader(collate_lists_fn: Optional[Callable]):
+        return ArrowReaderWorkerResultsQueueReader(collate_lists_fn)
 
     # pylint: disable=arguments-differ
     def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
diff --git a/petastorm/py_dict_reader_worker.py b/petastorm/py_dict_reader_worker.py
index 6c555797..4b8db193 100644
--- a/petastorm/py_dict_reader_worker.py
+++ b/petastorm/py_dict_reader_worker.py
@@ -16,6 +16,7 @@
 import hashlib
 import threading
 from collections.abc import Iterable
+from typing import Callable, Optional
 
 import numpy as np
 from pyarrow import parquet as pq
@@ -117,7 +118,9 @@ def __init__(self, worker_id, publish_func, args):
         self._dataset = None
 
     @staticmethod
-    def new_results_queue_reader():
+    def new_results_queue_reader(collate_lists_fn: Optional[Callable]):
+        if collate_lists_fn is not None:
+            raise NotImplementedError('PyDictReaderWorker does not support a custom collate_lists_fn')
         return PyDictReaderWorkerResultsQueueReader()
 
     # pylint: disable=arguments-differ
diff --git a/petastorm/reader.py b/petastorm/reader.py
index 082f8e63..459b0546 100644
--- a/petastorm/reader.py
+++ b/petastorm/reader.py
@@ -194,6 +194,7 @@ def make_reader(dataset_url,
         return Reader(filesystem, dataset_path,
                       worker_class=PyDictReaderWorker,
                       is_batched_reader=False,
+                      collate_lists_fn=None,
                       **kwargs)
     except PetastormMetadataError as e:
         logger.error('Unexpected exception: %s', str(e))
@@ -219,7 +220,8 @@ def make_batch_reader(dataset_url_or_urls,
                       filters=None,
                       storage_options=None,
                       zmq_copy_buffers=True,
-                      filesystem=None):
+                      filesystem=None,
+                      collate_lists_fn=None):
     """
     Creates an instance of Reader for reading batches out of a non-Petastorm Parquet store.
 
@@ -339,7 +341,8 @@ def make_batch_reader(dataset_url_or_urls,
                   cache=cache,
                   transform_spec=transform_spec,
                   is_batched_reader=True,
-                  filters=filters)
+                  filters=filters,
+                  collate_lists_fn=collate_lists_fn)
 
 
 class Reader(object):
@@ -353,7 +356,8 @@ class Reader(object):
     def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
                  shuffle_row_drop_partitions=1, predicate=None, rowgroup_selector=None, reader_pool=None,
                  num_epochs=1, cur_shard=None, shard_count=None, cache=None, worker_class=None,
-                 transform_spec=None, is_batched_reader=False, filters=None, shard_seed=None):
+                 transform_spec=None, is_batched_reader=False, filters=None, shard_seed=None,
+                 collate_lists_fn=None):
         """Initializes a reader object.
 
         :param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified,
@@ -429,7 +433,7 @@
         # By default, use original method of working with list of dictionaries and not arrow tables
         worker_class = worker_class or PyDictReaderWorker
 
-        self._results_queue_reader = worker_class.new_results_queue_reader()
+        self._results_queue_reader = worker_class.new_results_queue_reader(collate_lists_fn)
 
         if self.ngram and not self.ngram.timestamp_overlap and shuffle_row_drop_partitions > 1:
             raise NotImplementedError('Using timestamp_overlap=False is not implemented with'
diff --git a/petastorm/tests/test_parquet_reader.py b/petastorm/tests/test_parquet_reader.py
index ba15281c..1710fa83 100644
--- a/petastorm/tests/test_parquet_reader.py
+++ b/petastorm/tests/test_parquet_reader.py
@@ -24,7 +24,7 @@
 # pylint: disable=unnecessary-lambda
 from petastorm.tests.test_common import create_test_scalar_dataset
 from petastorm.transform import TransformSpec
-from petastorm.unischema import UnischemaField
+from petastorm.unischema import UnischemaField, Unischema
 
 _D = [lambda url, **kwargs: make_batch_reader(url, reader_pool_type='dummy', **kwargs)]
 
@@ -258,3 +258,23 @@ def test_random_seed(scalar_dataset):
     results.append(actual_row_ids)
     # Shuffled results are expected to be same
     np.testing.assert_equal(results[0], results[1])
+
+
+@pytest.mark.parametrize('reader_factory', _D + _TP)
+def test_read_with_collate(reader_factory, tmp_path):
+    data = pd.DataFrame({"str": ["a", "bc"], "varlen_nums": [[1], [3, 4]]})
+    path = tmp_path / 'data'
+    url = f"file://{path}"
+    data.to_parquet(path)
+
+    def collate_lists_fn(column_name: str, schema: Unischema, values):
+        # Zero-pad each list to the length of the longest one in the column.
+        max_len = max(map(len, values))
+        return np.asarray([np.pad(v, (0, max_len - len(v)), 'constant', constant_values=0) for v in values])
+
+    with reader_factory(url, collate_lists_fn=collate_lists_fn) as reader:
+        actual = list(reader)
+
+    assert len(actual) == 1
+    np.testing.assert_equal(actual[0].varlen_nums, [[1, 0], [3, 4]])
+    np.testing.assert_equal(actual[0].str, ["a", "bc"])
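
Usage note: the sketch below shows how the `collate_lists_fn` hook added by this diff could be used from user code. It mirrors the new test above; the dataset URL and the `pad_collate` helper name are hypothetical, while `make_batch_reader` and the `(column_name, schema, list_of_lists)` callback signature come from the diff itself. Leaving `collate_lists_fn` as `None` keeps the default `default_list_collate_fn` behavior, which vstacks equal-length lists and raises RuntimeError on ragged ones.

    # Hypothetical usage sketch of the collate_lists_fn hook (not part of the diff).
    import numpy as np

    from petastorm import make_batch_reader
    from petastorm.unischema import Unischema


    def pad_collate(column_name: str, schema: Unischema, values):
        # Zero-pad every list to the longest length in the column, instead of
        # letting the default vstack-based collate raise on ragged lists.
        max_len = max(map(len, values))
        return np.asarray([np.pad(v, (0, max_len - len(v)), 'constant', constant_values=0)
                           for v in values])


    # The URL below is a placeholder for a real non-Petastorm Parquet store.
    with make_batch_reader('file:///tmp/ragged_dataset', collate_lists_fn=pad_collate) as reader:
        for batch in reader:
            print(batch)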