From 0d422b89b894b486042af666b454efe2070d0e99 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 17 Sep 2024 13:26:34 -0400 Subject: [PATCH] Add stub methods (#2218) --- sdv/multi_table/base.py | 10 ++++++++++ tests/unit/multi_table/test_base.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 5565428f9..7ffd3eef3 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -340,6 +340,10 @@ def _store_and_convert_original_cols(self, data): data[table] = dataframe return list_of_changed_tables + def _transform_helper(self, data): + """Stub method for transforming data patterns.""" + return data + def preprocess(self, data): """Transform the raw data to numerical space. @@ -353,6 +357,7 @@ def preprocess(self, data): """ list_of_changed_tables = self._store_and_convert_original_cols(data) + data = self._transform_helper(data) self.validate(data) if self._fitted: warnings.warn( @@ -471,6 +476,10 @@ def reset_sampling(self): def _sample(self, scale): raise NotImplementedError() + def _reverse_transform_helper(self, sampled_data): + """Stub method for reverse transforming data patterns.""" + return sampled_data + def sample(self, scale=1.0): """Generate synthetic data for the entire dataset. @@ -495,6 +504,7 @@ def sample(self, scale=1.0): with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) + sampled_data = self._reverse_transform_helper(sampled_data) total_rows = 0 total_columns = 0 diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index f732fe827..464681938 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -763,6 +763,7 @@ def test_preprocess(self): 'id_upravna_enota': np.arange(10), }), } + instance._transform_helper = Mock(return_value=data) synth_nesreca = Mock() synth_oseba = Mock() @@ -782,6 +783,7 @@ def test_preprocess(self): 'oseba': synth_oseba._preprocess.return_value, 'upravna_enota': synth_upravna_enota._preprocess.return_value, } + instance._transform_helper.assert_called_once_with(data) instance.validate.assert_called_once_with(data) assert instance.metadata._get_all_foreign_keys.call_args_list == [ call('nesreca'), @@ -1212,6 +1214,7 @@ def test_sample(self, mock_datetime, caplog): 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), } instance._sample = Mock(return_value=data) + instance._reverse_transform_helper = Mock(return_value=data) synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' instance._synthesizer_id = synth_id @@ -1222,6 +1225,7 @@ def test_sample(self, mock_datetime, caplog): # Assert instance._sample.assert_called_once_with(scale=1.5) + instance._reverse_transform_helper.assert_called_once_with(data) assert caplog.messages[0] == str({ 'EVENT': 'Sample', 'TIMESTAMP': '2024-04-19 16:20:10.037183',