From fb1f0a0519eb1d63aba5be778663d016832083c8 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Wed, 31 May 2023 10:44:53 -0400 Subject: [PATCH 1/2] switch HMA over to using HierarchicalSampler mixin --- sdv/multi_table/hma.py | 253 ++---------- sdv/sampling/__init__.py | 4 + sdv/sampling/hierarchical_sampler.py | 2 +- sdv/single_table/base.py | 11 +- tests/unit/multi_table/test_hma.py | 360 +++--------------- .../sampling/test_hierarchical_sampler.py | 2 +- 6 files changed, 109 insertions(+), 523 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 19fbd1421..361115c7e 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -9,11 +9,12 @@ from tqdm import tqdm from sdv.multi_table.base import BaseMultiTableSynthesizer +from sdv.sampling import BaseHierarchicalSampler LOGGER = logging.getLogger(__name__) -class HMASynthesizer(BaseMultiTableSynthesizer): +class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): """Hierarchical Modeling Algorithm One. Args: @@ -31,13 +32,19 @@ class HMASynthesizer(BaseMultiTableSynthesizer): } def __init__(self, metadata, locales=None, verbose=True): - super().__init__(metadata, locales=locales) + BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales) self._table_sizes = {} self._max_child_rows = {} self._augmented_tables = [] self._learned_relationships = 0 self.verbose = verbose + BaseHierarchicalSampler.__init__( + self, + self.metadata, + self._table_synthesizers, + self._table_sizes) + def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc): """Generate the extension columns for this child table. 
@@ -115,15 +122,6 @@ def _clear_nans(table_data): table_data[column] = table_data[column].fillna(fill_value) - def _get_foreign_keys(self, table_name, child_name): - foreign_keys = [] - for relation in self.metadata.relationships: - if table_name == relation['parent_table_name'] and\ - child_name == relation['child_table_name']: - foreign_keys.append(deepcopy(relation['child_foreign_key'])) - - return foreign_keys - def _augment_table(self, table, tables, table_name): """Generate the extension columns for this table. @@ -152,7 +150,7 @@ def _augment_table(self, table, tables, table_name): else: child_table = tables[child_name] - foreign_keys = self._get_foreign_keys(table_name, child_name) + foreign_keys = self.metadata._get_foreign_keys(table_name, child_name) for foreign_key in foreign_keys: progress_bar_desc = ( f'({self._learned_relationships + 1}/{len(self.metadata.relationships)}) ' @@ -241,46 +239,6 @@ def _augment_tables(self, processed_data): LOGGER.info('Augmentation Complete') return augmented_data - def _finalize(self, sampled_data): - """Do the final touches to the generated data. - - This method reverts the previous transformations to go back - to values in the original space and also adds the parent - keys in case foreign key relationships exist between the tables. - - Args: - sampled_data (dict): - Generated data - - Return: - pandas.DataFrame: - Formatted synthesized data. 
- """ - final_data = {} - for table_name, table_rows in sampled_data.items(): - parents = self.metadata._get_parent_map().get(table_name) - if parents: - for parent_name in parents: - foreign_keys = self._get_foreign_keys(parent_name, table_name) - for foreign_key in foreign_keys: - if foreign_key not in table_rows: - parent_ids = self._find_parent_ids( - table_name, - parent_name, - foreign_key, - sampled_data - ) - table_rows[foreign_key] = parent_ids.to_numpy() - - synthesizer = self._table_synthesizers.get(table_name) - dtypes = synthesizer._data_processor._dtypes - for name, dtype in dtypes.items(): - table_rows[name] = table_rows[name].dropna().astype(dtype) - - final_data[table_name] = table_rows[list(dtypes.keys())] - - return final_data - def _extract_parameters(self, parent_row, table_name, foreign_key): """Get the params from a generated parent row. @@ -308,105 +266,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): return flat_parameters.rename(new_keys).to_dict() - def _process_samples(self, table_name, sampled_rows): - """Process the ``sampled_rows`` for the given ``table_name``. - - Process the raw samples and convert them to the original space by reverse transforming - them. Also, when there are synthesizer columns (columns used to recreate an instance - of a synthesizer), those will be returned together. - """ - data_processor = self._table_synthesizers[table_name]._data_processor - sampled = data_processor.reverse_transform(sampled_rows) - - synthesizer_columns = list(set(sampled_rows.columns) - set(sampled.columns)) - if synthesizer_columns: - sampled = pd.concat([sampled, sampled_rows[synthesizer_columns]], axis=1) - - return sampled - - def _sample_rows(self, synthesizer, table_name, num_rows=None): - """Sample ``num_rows`` from ``synthesizer``. - - Args: - synthesizer (copula.multivariate.base): - Fitted synthesizer. - table_name (str): - Name of the table to sample from. - num_rows (int): - Number of rows to sample. 
+ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): + foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0] + parameters = self._extract_parameters(parent_row, child_name, foreign_key) + table_meta = self.metadata.tables[child_name] - Returns: - pandas.DataFrame: - Sampled rows, shape (, num_rows) - """ - num_rows = num_rows or synthesizer._num_rows - if synthesizer._model: - sampled_rows = synthesizer._sample(num_rows) - else: - sampled_rows = pd.DataFrame(index=range(num_rows)) - - return self._process_samples(table_name, sampled_rows) - - def _get_child_synthesizer(self, parent_row, table_name, foreign_key): - parameters = self._extract_parameters(parent_row, table_name, foreign_key) - table_meta = self.metadata.tables[table_name] - synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name]) + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) synthesizer._set_parameters(parameters) + synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor return synthesizer - def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data): - """Sample child rows that reference the given parent row. - - The sampled rows will be stored in ``sampled_data`` under the ``table_name`` key. - - Args: - table_name (str): - The name of the table to sample. - parent_name (str): - The name of the parent table. - parent_row (pandas.Series): - The parent row the child rows should reference. - sampled_data (dict): - A map of table name to sampled table data (pandas.DataFrame). 
- """ - foreign_key = self._get_foreign_keys(parent_name, table_name)[0] - synthesizer = self._get_child_synthesizer(parent_row, table_name, foreign_key) - table_rows = self._sample_rows(synthesizer, table_name) - - if len(table_rows): - parent_key = self.metadata.tables[parent_name].primary_key - table_rows[foreign_key] = parent_row[parent_key] - - previous = sampled_data.get(table_name) - if previous is None: - sampled_data[table_name] = table_rows - else: - sampled_data[table_name] = pd.concat( - [previous, table_rows]).reset_index(drop=True) - - def _sample_children(self, table_name, sampled_data, table_rows): - """Recursively sample the child tables of the given table. - - Sampled child data will be stored into `sampled_data`. - - Args: - table_name (str): - The name of the table whose children will be sampled. - sampled_data (dict): - A map of table name to the sampled table data (pandas.DataFrame). - table_rows (pandas.DataFrame): - The sampled rows of the given table. - """ - for child_name in self.metadata._get_child_map()[table_name]: - if child_name not in sampled_data: - LOGGER.info('Sampling rows from child table %s', child_name) - for _, row in table_rows.iterrows(): - self._sample_child_rows(child_name, table_name, row, sampled_data) - - child_rows = sampled_data[child_name] - self._sample_children(child_name, sampled_data, child_rows) - @staticmethod def _find_parent_id(likelihoods, num_rows): """Find the parent id for one row based on the likelihoods of parent id values. @@ -484,7 +354,7 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): return pd.DataFrame(likelihoods, index=table_rows.index) - def _find_parent_ids(self, table_name, parent_name, foreign_key, sampled_data): + def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key): """Find parent ids for the given table and foreign key. 
The parent ids are chosen randomly based on the likelihood of the available @@ -492,81 +362,36 @@ def _find_parent_ids(self, table_name, parent_name, foreign_key, sampled_data): will first sample rows for the parent table. Args: - table_name (str): - The name of the table to find parent ids for. - parent_name (str): - The name of the parent table. + child_table (pd.DataFrame): + + parent_table (pd.DataFrame): + + child_name (str): + + parent_name (dict): + Map of table name to sampled data (pandas.DataFrame). foreign_key (str): The name of the foreign key column in the child table. - sampled_data (dict): - Map of table name to sampled data (pandas.DataFrame). Returns: pandas.Series: The parent ids for the given table data. """ - table_rows = sampled_data[table_name] - if parent_name in sampled_data: - parent_rows = sampled_data[parent_name] - else: - ratio = self._table_sizes[parent_name] / self._table_sizes[table_name] - num_parent_rows = max(int(round(len(table_rows) * ratio)), 1) - parent_model = self._table_synthesizers[parent_name] - parent_rows = self._sample_rows(parent_model, parent_name, num_parent_rows) - primary_key = self.metadata.tables[parent_name].primary_key - parent_rows = parent_rows.set_index(primary_key) - num_rows = parent_rows[f'__{table_name}__{foreign_key}__num_rows'].fillna(0).clip(0) + parent_table = parent_table.set_index(primary_key) + num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0).clip(0) - likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name, foreign_key) + likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key) return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) - def _sample_table(self, table_name, scale=1.0, sample_children=True, sampled_data=None): - """Sample a single table and optionally its children.""" - if sampled_data is None: - sampled_data = {} - - num_rows = int(self._table_sizes[table_name] * scale) - - 
LOGGER.info('Sampling %s rows from table %s', num_rows, table_name) - - synthesizer = self._table_synthesizers[table_name] - table_rows = self._sample_rows(synthesizer, table_name, num_rows) - sampled_data[table_name] = table_rows - - if sample_children: - self._sample_children(table_name, sampled_data, table_rows) - - return sampled_data - - def _sample(self, scale=1.0): - """Sample the entire dataset. - - Returns a dictionary with all the tables of the dataset. The amount of rows sampled will - depend from table to table. This is because the children tables are created modelling the - relation that they have with their parent tables, so its behavior may change from one - table to another. - - Args: - scale (float): - A float representing how much to scale the data by. If scale is set to ``1.0``, - this does not scale the sizes of the tables. If ``scale`` is greater than ``1.0`` - create more rows than the original data by a factor of ``scale``. - If ``scale`` is lower than ``1.0`` create fewer rows by the factor of ``scale`` - than the original tables. Defaults to ``1.0``. - - Returns: - dict: - A dictionary containing as keys the names of the tables and as values the - sampled data tables as ``pandas.DataFrame``. - - Raises: - NotFittedError: - A ``NotFittedError`` is raised when the ``SDV`` instance has not been fitted yet. 
- """ - sampled_data = {} - for table in self.metadata.tables: - if not self.metadata._get_parent_map().get(table): - self._sample_table(table, scale=scale, sampled_data=sampled_data) - - return self._finalize(sampled_data) + def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name): + for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name): + if foreign_key not in child_table: + parent_ids = self._find_parent_ids( + child_table=child_table, + parent_table=parent_table, + child_name=child_name, + parent_name=parent_name, + foreign_key=foreign_key + ) + child_table[foreign_key] = parent_ids.to_numpy() diff --git a/sdv/sampling/__init__.py b/sdv/sampling/__init__.py index 47034a007..1cec8a2a6 100644 --- a/sdv/sampling/__init__.py +++ b/sdv/sampling/__init__.py @@ -1,7 +1,11 @@ """SDV Sampling module.""" +from sdv.sampling.hierarchical_sampler import BaseHierarchicalSampler +from sdv.sampling.independent_sampler import BaseIndependentSampler from sdv.sampling.tabular import Condition __all__ = [ + 'BaseHierarchicalSampler', + 'BaseIndependentSampler', 'Condition', ] diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index a2fcad882..72487d64a 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -68,7 +68,7 @@ def _sample_rows(self, synthesizer, num_rows=None): Sampled rows, shape (, num_rows) """ num_rows = num_rows or synthesizer._num_rows - return synthesizer._sample_batch(num_rows, remove_extra_columns=False) + return synthesizer._sample_batch(int(num_rows), keep_extra_columns=True) def _get_num_rows_from_parent(self, parent_row, child_name, foreign_key): """Get the number of rows to sample for the child from the parent row.""" diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 50b481f9b..6a690eb28 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -594,7 +594,7 @@ def _sample_rows(self, num_rows, 
conditions=None, transformed_conditions=None, Maximum tolerance when considering a float match. previous_rows (pandas.DataFrame): Valid rows sampled in the previous iterations. - remove_extra_columns (bool): + keep_extra_columns (bool): Whether to keep extra columns from the sampled data. Defaults to False. Returns: @@ -604,10 +604,11 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, * int: Number of rows that are considered valid. """ - if not self._random_state_set: - self._set_random_state(FIXED_RNG_SEED) + if (self._model and ( + self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns)): + if not self._random_state_set: + self._set_random_state(FIXED_RNG_SEED) - if self._data_processor.get_sdtypes(primary_keys=False): if conditions is None: raw_sampled = self._sample(num_rows) else: @@ -682,7 +683,7 @@ def _sample_batch(self, batch_size, max_tries=100, output_file_path (str or None): The file to periodically write sampled rows to. If None, does not write rows anywhere. - remove_extra_columns (bool): + keep_extra_columns (bool): Whether to keep extra columns from the sampled data. Defaults to False. 
Returns: diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 7208a3e61..9e086066b 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -91,18 +91,6 @@ def test__get_extension_foreign_key_only(self): pd.testing.assert_frame_equal(result, expected) - def test__get_foreign_keys(self): - """Test that this method returns the foreign keys for a given table name and child name.""" - # Setup - metadata = get_multi_table_metadata() - instance = HMASynthesizer(metadata) - - # Run - result = instance._get_foreign_keys('nesreca', 'oseba') - - # Assert - assert result == ['id_nesreca'] - def test__get_all_foreign_keys(self): """Test that this method returns all the foreign keys for a given table name.""" # Setup @@ -284,8 +272,8 @@ def test__augment_tables(self): def test__finalize(self): """Test that the finalize method applies the final touches to the generated data. - The process consists of applying the propper data types to each table, and finding - foreign keys if those are not present in the current sampled data. + The process consists of applying the proper data types to each table, and dropping + extra columns not present in the metadata. """ # Setup instance = Mock() @@ -296,9 +284,6 @@ def test__finalize(self): } instance.metadata = metadata - instance._get_foreign_keys.side_effect = [['user_id'], ['session_id']] - instance._find_parent_ids.return_value = pd.Series(['a', 'a', 'b'], name='session_id') - sampled_data = { 'users': pd.DataFrame({ 'user_id': pd.Series([0, 1, 2], dtype=np.int64), @@ -314,6 +299,7 @@ def test__finalize(self): }), 'transactions': pd.DataFrame({ 'transaction_id': pd.Series([1, 2, 3], dtype=np.int64), + 'session_id': pd.Series(['a', 'a', 'b'], dtype=object), }), } @@ -388,311 +374,36 @@ def test__extract_parameters(self): assert result == expected_result - def test__process_samples(self): - """Test the ``_process_samples``. 
- - Test that the method retrieves the ``data_processor`` from the fitted ``table_synthesizer`` - and performs a ``reverse_transform`` and returns the data in the real space. - """ - # Setup - sampled_rows = pd.DataFrame({ - 'name': [0.1, 0.25, 0.35], - 'a': [1.0, 0.25, 0.5], - 'b': [0.2, 0.6, 0.9], - 'loc': [0.5, 0.1, 0.2], - 'num_rows': [1, 2, 3], - 'scale': [0.25, 0.35, 0.15] - }) - instance = Mock() - users_synthesizer = Mock() - users_synthesizer._data_processor.reverse_transform.return_value = pd.DataFrame({ - 'user_id': [0, 1, 2], - 'name': ['John', 'Doe', 'Johanna'] - }) - instance._table_synthesizers = {'users': users_synthesizer} - - # Run - result = HMASynthesizer._process_samples(instance, 'users', sampled_rows) - - # Assert - expected_result = pd.DataFrame({ - 'user_id': [0, 1, 2], - 'name': ['John', 'Doe', 'Johanna'], - 'a': [1.0, 0.25, 0.5], - 'b': [0.2, 0.6, 0.9], - 'loc': [0.5, 0.1, 0.2], - 'num_rows': [1, 2, 3], - 'scale': [0.25, 0.35, 0.15] - }) - result = result.reindex(sorted(result.columns), axis=1) - expected_result = expected_result.reindex(sorted(expected_result.columns), axis=1) - pd.testing.assert_frame_equal(result, expected_result) - - def test__sample_rows(self): - """Test sample rows. - - Test that sampling rows will return the reverse transformed data with the extension columns - sampled by the model. 
- """ - # Setup - synthesizer = Mock() - instance = Mock() - - # Run - result = HMASynthesizer._sample_rows(instance, synthesizer, 'users', 10) - - # Assert - assert result == instance._process_samples.return_value - instance._process_samples.assert_called_once_with( - 'users', - synthesizer._sample.return_value - ) - synthesizer._sample.assert_called_once_with(10) - - def test__get_child_synthesizer(self): + def test__recreate_child_synthesizer(self): """Test that this method returns a synthesizer for the given child table.""" # Setup instance = Mock() parent_row = 'row' table_name = 'users' - foreign_key = 'session_id' + parent_table_name = 'sessions' table_meta = Mock() + table_synthesizer = Mock() instance.metadata.tables = {'users': table_meta} + instance.metadata._get_foreign_keys.return_value = ['session_id'] instance._table_parameters = {'users': {'a': 1}} + instance._table_synthesizers = {'users': table_synthesizer} # Run - synthesizer = HMASynthesizer._get_child_synthesizer( + synthesizer = HMASynthesizer._recreate_child_synthesizer( instance, - parent_row, table_name, - foreign_key + parent_table_name, + parent_row, ) # Assert assert synthesizer == instance._synthesizer.return_value + assert synthesizer._data_processor == table_synthesizer._data_processor instance._synthesizer.assert_called_once_with(table_meta, a=1) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value ) - instance._extract_parameters.assert_called_once_with(parent_row, table_name, foreign_key) - - def test__sample_child_rows(self): - """Test the sampling of child rows when sampled data is empty.""" - # Setup - instance = Mock() - instance._get_foreign_keys.return_value = ['user_id'] - instance._extract_parameters.return_value = { - 'a': 1.0, - 'b': 0.2, - 'loc': 0.5, - 'num_rows': 10.0, - 'scale': 0.25 - } - - metadata = Mock() - sessions_meta = Mock() - users_meta = Mock() - users_meta.primary_key.return_value = 'user_id' - metadata.tables = { 
- 'users': users_meta, - 'sessions': sessions_meta - } - instance.metadata = metadata - instance._synthesizer_kwargs = {'a': 0.1, 'b': 0.5, 'loc': 0.25} - - instance._sample_rows.return_value = pd.DataFrame({ - 'session_id': ['a', 'b', 'c'], - 'os': ['linux', 'mac', 'win'], - 'country': ['us', 'us', 'es'], - }) - parent_row = pd.DataFrame({ - 'user_id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] - }) - sampled_data = {} - - # Run - HMASynthesizer._sample_child_rows(instance, 'sessions', 'users', parent_row, sampled_data) - - # Assert - expected_result = pd.DataFrame({ - 'session_id': ['a', 'b', 'c'], - 'os': ['linux', 'mac', 'win'], - 'country': ['us', 'us', 'es'], - 'user_id': [1, 2, 3] - }) - pd.testing.assert_frame_equal(sampled_data['sessions'], expected_result) - - def test__sample_child_rows_with_sampled_data(self): - """Test the sampling of child rows when sampled data contains values. - - The new sampled data has to be concatenated to the current sampled data. - """ - # Setup - instance = Mock() - instance._get_foreign_keys.return_value = ['user_id'] - instance._extract_parameters.return_value = { - 'a': 1.0, - 'b': 0.2, - 'loc': 0.5, - 'num_rows': 10.0, - 'scale': 0.25 - } - - metadata = Mock() - sessions_meta = Mock() - users_meta = Mock() - users_meta.primary_key.return_value = 'user_id' - metadata.tables = { - 'users': users_meta, - 'sessions': sessions_meta - } - instance.metadata = metadata - instance._synthesizer_kwargs = {'a': 0.1, 'b': 0.5, 'loc': 0.25} - - instance._sample_rows.return_value = pd.DataFrame({ - 'session_id': ['a', 'b', 'c'], - 'os': ['linux', 'mac', 'win'], - 'country': ['us', 'us', 'es'], - }) - parent_row = pd.DataFrame({ - 'user_id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] - }) - sampled_data = { - 'sessions': pd.DataFrame({ - 'user_id': [0, 1, 0], - 'session_id': ['d', 'e', 'f'], - 'os': ['linux', 'mac', 'win'], - 'country': ['us', 'us', 'es'], - }) - } - - # Run - HMASynthesizer._sample_child_rows(instance, 
'sessions', 'users', parent_row, sampled_data) - - # Assert - expected_result = pd.DataFrame({ - 'user_id': [0, 1, 0, 1, 2, 3], - 'session_id': ['d', 'e', 'f', 'a', 'b', 'c'], - 'os': ['linux', 'mac', 'win', 'linux', 'mac', 'win'], - 'country': ['us', 'us', 'es', 'us', 'us', 'es'], - }) - pd.testing.assert_frame_equal(sampled_data['sessions'], expected_result) - - def test__sample_children(self): - """Test that child tables are being sampled recursively.""" - # Setup - def update_sampled_data(child_name, table_name, row, sampled_data): - sampled_data['sessions'] = pd.DataFrame({ - 'user_id': [1], - 'session_id': ['d'], - 'os': ['linux'], - 'country': ['us'], - }) - - metadata = Mock() - metadata._get_child_map.return_value = {'users': ['sessions']} - instance = Mock() - table_rows = pd.DataFrame({ - 'user_id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] - }) - sampled_data = {} - instance.metadata = metadata - instance._sample_child_rows.side_effect = update_sampled_data - - # Run - HMASynthesizer._sample_children(instance, 'users', sampled_data, table_rows) - - # Assert - assert instance._sample_child_rows.call_count == 3 - sample_calls = instance._sample_child_rows.call_args_list - pd.testing.assert_series_equal( - sample_calls[0][0][2], - pd.Series({'user_id': 1, 'name': 'John'}, name=0) - ) - pd.testing.assert_series_equal( - sample_calls[1][0][2], - pd.Series({'user_id': 2, 'name': 'Doe'}, name=1) - ) - pd.testing.assert_series_equal( - sample_calls[2][0][2], - pd.Series({'user_id': 3, 'name': 'Johanna'}, name=2) - ) - - def test__sample_table(self): - """Test sampling a table. - - The ``sample_table`` method will call sample children and return the sampled data - dictionary. 
- """ - # Setup - def sample_children(table_name, sampled_data, table_rows): - sampled_data['sessions'] = pd.DataFrame({ - 'user_id': [1, 1, 3], - 'session_id': ['a', 'b', 'c'], - 'os': ['windows', 'linux', 'mac'], - 'country': ['us', 'us', 'es'] - }) - sampled_data['transactions'] = pd.DataFrame({ - 'transaction_id': [1, 2, 3], - 'session_id': ['a', 'a', 'b'] - }) - - instance = Mock() - instance._table_sizes = {'users': 10} - instance._table_synthesizers = {'users': Mock()} - instance._sample_children.side_effect = sample_children - instance._sample_rows.return_value = pd.DataFrame({ - 'user_id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] - }) - - # Run - result = HMASynthesizer._sample_table(instance, 'users') - - # Assert - expected_result = { - 'users': pd.DataFrame({ - 'user_id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'], - }), - 'sessions': pd.DataFrame({ - 'user_id': [1, 1, 3], - 'session_id': ['a', 'b', 'c'], - 'os': ['windows', 'linux', 'mac'], - 'country': ['us', 'us', 'es'], - }), - 'transactions': pd.DataFrame({ - 'transaction_id': [1, 2, 3], - 'session_id': ['a', 'a', 'b'] - }) - } - for result_frame, expected_frame in zip(result.values(), expected_result.values()): - pd.testing.assert_frame_equal(result_frame, expected_frame) - - def test__sample(self): - """Test that the ``_sample_table`` is called for tables that don't have parents.""" - # Setup - instance = Mock() - instance.metadata._get_parent_map.return_value = { - 'sessions': ['users'], - 'transactions': ['sessions'] - } - instance.metadata.tables = { - 'users': Mock(), - 'sessions': Mock(), - 'transactions': Mock(), - } - - # Run - result = HMASynthesizer._sample(instance) - - # Assert - assert result == instance._finalize.return_value - instance._sample_table.assert_called_once_with('users', scale=1, sampled_data={}) - instance._finalize.assert_called_once_with({}) + instance._extract_parameters.assert_called_once_with(parent_row, table_name, 'session_id') def 
test_get_learned_distributions(self): """Test that ``get_learned_distributions`` returns a dict. @@ -752,3 +463,48 @@ def test_get_learned_distributions_raises_an_error(self): ) with pytest.raises(ValueError, match=error_msg): instance.get_learned_distributions('upravna_enota') + + def test__add_foreign_key_columns(self): + """Test that the ``_add_foreign_key_columns`` method adds foreign keys.""" + # Setup + instance = Mock() + metadata = Mock() + metadata._get_foreign_keys.return_value = ['primary_user_id', 'secondary_user_id'] + instance.metadata = metadata + + instance._find_parent_ids.return_value = pd.Series([2, 1, 2], name='secondary_user_id') + + parent_table = pd.DataFrame({ + 'user_id': pd.Series([0, 1, 2], dtype=np.int64), + 'name': pd.Series(['John', 'Doe', 'Johanna'], dtype=object), + }) + child_table = pd.DataFrame({ + 'transaction_id': pd.Series([1, 2, 3], dtype=np.int64), + 'primary_user_id': pd.Series([0, 0, 1], dtype=np.int64) + }) + + instance._table_synthesizers = { + 'users': Mock(), + 'transactions': Mock() + } + + # Run + HMASynthesizer._add_foreign_key_columns( + instance, + child_table, + parent_table, + 'transactions', + 'users') + + # Assert + expected_parent_table = pd.DataFrame({ + 'user_id': pd.Series([0, 1, 2], dtype=np.int64), + 'name': pd.Series(['John', 'Doe', 'Johanna'], dtype=object), + }) + expected_child_table = pd.DataFrame({ + 'transaction_id': pd.Series([1, 2, 3], dtype=np.int64), + 'primary_user_id': pd.Series([0, 0, 1], dtype=np.int64), + 'secondary_user_id': pd.Series([2, 1, 2], dtype=np.int64) + }) + pd.testing.assert_frame_equal(expected_parent_table, parent_table) + pd.testing.assert_frame_equal(expected_child_table, child_table) diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index 030b86e87..b24a772d0 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -59,7 +59,7 @@ def 
test__sample_rows(self): assert result == synthesizer._sample_batch.return_value synthesizer._sample_batch.assert_called_once_with( 10, - remove_extra_columns=False + keep_extra_columns=True ) def test__get_num_rows_from_parent(self): From 757ea166896b8c3173750ad119e59738554edfda Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 20 Jun 2023 09:22:26 -0400 Subject: [PATCH 2/2] comment --- sdv/multi_table/hma.py | 10 +++++----- sdv/single_table/base.py | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 361115c7e..9615257fc 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -358,16 +358,15 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f """Find parent ids for the given table and foreign key. The parent ids are chosen randomly based on the likelihood of the available - parent ids in the parent table. If the parent table is not sampled, this method - will first sample rows for the parent table. + parent ids in the parent table. Args: child_table (pd.DataFrame): - + The child table dataframe. parent_table (pd.DataFrame): - + The parent table dataframe. child_name (str): - + The name of the child table. parent_name (dict): Map of table name to sampled data (pandas.DataFrame). foreign_key (str): @@ -377,6 +376,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f pandas.Series: The parent ids for the given table data. 
""" + # Create a copy of the parent table with the primary key as index to calculate likilihoods primary_key = self.metadata.tables[parent_name].primary_key parent_table = parent_table.set_index(primary_key) num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0).clip(0) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 6a690eb28..4bc080a1e 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -604,10 +604,11 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, * int: Number of rows that are considered valid. """ - if (self._model and ( - self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns)): - if not self._random_state_set: - self._set_random_state(FIXED_RNG_SEED) + if self._model and not self._random_state_set: + self._set_random_state(FIXED_RNG_SEED) + + need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns + if self._model and need_sample: if conditions is None: raw_sampled = self._sample(num_rows)