From bbe4921dc0eb5d14007d78bc0430980b89f50808 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 11 Jun 2024 10:56:29 -0500 Subject: [PATCH 1/3] Remove the temp file from sample --- sdv/lite/single_table.py | 6 ++-- sdv/single_table/base.py | 29 ++++--------------- sdv/single_table/utils.py | 24 ++++++---------- tests/unit/single_table/test_base.py | 30 ++++++++++++++------ tests/unit/single_table/test_utils.py | 40 ++++++++++++++------------- 5 files changed, 58 insertions(+), 71 deletions(-) diff --git a/sdv/lite/single_table.py b/sdv/lite/single_table.py index a7f9479b8..9e4f2129a 100644 --- a/sdv/lite/single_table.py +++ b/sdv/lite/single_table.py @@ -136,8 +136,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, The batch size to use per attempt at sampling. Defaults to 10 times the number of rows. output_file_path (str or None): - The file to periodically write sampled rows to. Defaults to - a temporary file, if None. + The file to periodically write sampled rows to. Returns: pandas.DataFrame: @@ -168,8 +167,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, The batch size to use per attempt at sampling. Defaults to 10 times the number of rows. output_file_path (str or None): - The file to periodically write sampled rows to. Defaults to - a temporary file, if None. + The file to periodically write sampled rows to. Returns: pandas.DataFrame: diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index f34045363..154efbd93 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -33,8 +33,6 @@ COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 -TMP_FILE_NAME = '.sample.csv.temp' -DISABLE_TMP_FILE = 'disable' class BaseSynthesizer: @@ -859,11 +857,7 @@ def _sample_with_progress_bar(self, num_rows, max_tries_per_batch=100, batch_siz ) except (Exception, KeyboardInterrupt) as error: - handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error) - - else: - if output_file_path == TMP_FILE_NAME and os.path.exists(output_file_path): - os.remove(output_file_path) + handle_sampling_error(output_file_path, error) return sampled @@ -935,8 +929,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar (tqdm.tqdm or None): The progress bar to update. output_file_path (str or None): - The file to periodically write sampled rows to. Defaults to - a temporary file, if None. + The file to periodically write sampled rows to. Returns: pandas.DataFrame: @@ -1059,8 +1052,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, batch_size (int): The batch size to use per sampling call. output_file_path (str or None): - The file to periodically write sampled rows to. Defaults to - a temporary file, if None. + The file to periodically write sampled rows to. Returns: pandas.DataFrame: @@ -1106,11 +1098,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, ) except (Exception, KeyboardInterrupt) as error: - handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error) - - else: - if output_file_path == TMP_FILE_NAME and os.path.exists(output_file_path): - os.remove(output_file_path) + handle_sampling_error(output_file_path, error) return sampled @@ -1139,8 +1127,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, batch_size (int): The batch size to use per sampling call. output_file_path (str or None): - The file to periodically write sampled rows to. Defaults to - a temporary file, if None. + The file to periodically write sampled rows to. Returns: pandas.DataFrame: @@ -1176,10 +1163,6 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, ) except (Exception, KeyboardInterrupt) as error: - handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error) - - else: - if output_file_path == TMP_FILE_NAME and os.path.exists(output_file_path): - os.remove(output_file_path) + handle_sampling_error(output_file_path, error) return sampled diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index d9bdb72d7..1ecb68267 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -7,7 +7,6 @@ from sdv.errors import SynthesizerInputError -TMP_FILE_NAME = '.sample.csv.temp' DISABLE_TMP_FILE = 'disable' IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type'] @@ -79,12 +78,10 @@ def detect_discrete_columns(metadata, data, transformers): return discrete_columns -def handle_sampling_error(is_tmp_file, output_file_path, sampling_error): +def handle_sampling_error(output_file_path, sampling_error): """Handle sampling errors by printing a user-legible error and then raising. Args: - is_tmp_file (bool): - Whether or not the output file is a temp file. output_file_path (str): The output file path. sampling_error: @@ -97,15 +94,14 @@ def handle_sampling_error(is_tmp_file, output_file_path, sampling_error): raise sampling_error error_msg = None - if is_tmp_file: + if output_file_path is not None: error_msg = ( - 'Error: Sampling terminated. Partial results are stored in a temporary file: ' - f'{output_file_path}. This file will be overridden the next time you sample. ' - 'Please rename the file if you wish to save these results.' + f'Error: Sampling terminated. Partial results are stored in {output_file_path}.' ) - elif output_file_path is not None: + else: error_msg = ( - f'Error: Sampling terminated. Partial results are stored in {output_file_path}.' + 'Error: Sampling terminated. No results were saved due to unspecified ' + '"output_file_path".' ) if error_msg: @@ -166,17 +162,13 @@ def validate_file_path(output_file_path): if output_file_path == DISABLE_TMP_FILE: # Temporary way of disabling the output file feature, used by HMA1. return output_path - elif output_file_path: output_path = os.path.abspath(output_file_path) if os.path.exists(output_path): raise AssertionError(f'{output_path} already exists.') - else: - if os.path.exists(TMP_FILE_NAME): - os.remove(TMP_FILE_NAME) - - output_path = TMP_FILE_NAME + # Do not save a file if the user specified not to save a file. + return None # Create the file. with open(output_path, 'w+'): diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 8e9989127..a9a3a9ee3 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1413,7 +1413,7 @@ def test__sample_with_progress_bar_handle_sampling_error( progress_bar=progress_bar.__enter__.return_value, output_file_path=mock_validate_file_path.return_value, ) - mock_handle_sampling_error.assert_called_once_with(False, 'temp_file', keyboard_error) + mock_handle_sampling_error.assert_called_once_with('temp_file', keyboard_error) @patch('sdv.single_table.base.os') @patch('sdv.single_table.base.tqdm') @@ -1446,8 +1446,6 @@ def test__sample_with_progress_bar_removing_temp_file( progress_bar=progress_bar.__enter__.return_value, output_file_path=mock_validate_file_path.return_value, ) - mock_os.remove.assert_called_once_with('.sample.csv.temp') - mock_os.path.exists.assert_called_once_with('.sample.csv.temp') def test_sample_not_fitted(self): """Test that ``sample`` raises an error when the synthesizer is not fitted.""" @@ -1463,6 +1461,24 @@ def test_sample_not_fitted(self): with pytest.raises(SamplingError, match=expected_message): BaseSingleTableSynthesizer.sample(instance, 10) + def test__sample_with_progress_bar_without_output_filepath(self): + """Test that ``_sample_with_progress_bar`` raises an error + when the synthesizer is not fitted. + """ + # Setup + instance = Mock() + instance._fitted = True + expected_message = re.escape( + 'Error: Sampling terminated. No results were saved due to unspecified ' + '"output_file_path".\nMocked Error' + ) + instance._sample_in_batches.side_effect = RuntimeError('Mocked Error') + + # Run and Assert + with pytest.raises(RuntimeError, match=expected_message): + BaseSingleTableSynthesizer._sample_with_progress_bar( + instance, output_file_path=None, num_rows=10) + @patch('sdv.single_table.base.datetime') def test_sample(self, mock_datetime, caplog): """Test that we use ``_sample_with_progress_bar`` in this method.""" @@ -1715,8 +1731,6 @@ def test_sample_from_conditions(self, mock_validate_file_path, mock_tqdm, # Assert pd.testing.assert_frame_equal(result, pd.DataFrame({'name': ['John Doe']})) - mock_os.remove.assert_called_once_with('.sample.csv.temp') - mock_os.path.exists.assert_called_once_with('.sample.csv.temp') @patch('sdv.single_table.base.handle_sampling_error') @patch('sdv.single_table.base.tqdm') @@ -1751,7 +1765,7 @@ def test_sample_from_conditions_handle_sampling_error(self, progress_bar.__enter__.return_value.set_description.assert_called_once_with( 'Sampling conditions' ) - mock_handle_sampling_error.assert_called_once_with(False, 'temp_file', keyboard_error) + mock_handle_sampling_error.assert_called_once_with('temp_file', keyboard_error) @patch('sdv.single_table.base.os') @patch('sdv.single_table.base.check_num_rows') @@ -1785,8 +1799,6 @@ def test_sample_remaining_columns(self, mock_validate_file_path, mock_tqdm, # Assert pd.testing.assert_frame_equal(result, pd.DataFrame({'name': ['John Doe']})) - mock_os.remove.assert_called_once_with('.sample.csv.temp') - mock_os.path.exists.assert_called_once_with('.sample.csv.temp') @patch('sdv.single_table.base.handle_sampling_error') @patch('sdv.single_table.base.check_num_rows') @@ -1826,7 +1838,7 @@ def test_sample_remaining_columns_handles_sampling_error( # Assert pd.testing.assert_frame_equal(result, pd.DataFrame()) - mock_handle_sampling_error.assert_called_once_with(False, 'temp_file', keyboard_error) + mock_handle_sampling_error.assert_called_once_with('temp_file', keyboard_error) def test__validate_known_columns_nans(self): """Test that it crashes when condition has nans.""" diff --git a/tests/unit/single_table/test_utils.py b/tests/unit/single_table/test_utils.py index d3e6ffd33..dcb941bd3 100644 --- a/tests/unit/single_table/test_utils.py +++ b/tests/unit/single_table/test_utils.py @@ -8,7 +8,7 @@ from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.utils import ( _key_order, check_num_rows, detect_discrete_columns, flatten_array, flatten_dict, - handle_sampling_error, unflatten_dict) + handle_sampling_error, unflatten_dict, validate_file_path) def test_detect_discrete_columns(): @@ -223,21 +223,7 @@ def test_unflatten_dict(): assert result == expected -def test_handle_sampling_error(): - """Test when an error is raised at the end of the function when temp dir is ``True``.""" - # Run and Assert - error_msg = ( - 'Error: Sampling terminated. Partial results are stored in a temporary file: test.csv. ' - 'This file will be overridden the next time you sample. Please rename the file if you ' - 'wish to save these results.' - '\n' - 'Test error' - ) - with pytest.raises(ValueError, match=error_msg): - handle_sampling_error(True, 'test.csv', ValueError('Test error')) - - -def test_handle_sampling_error_false_temp_file(): +def test_handle_sampling_error_temp_file(): """Test that an error is raised when temp dir is ``False``.""" # Run and Assert error_msg = ( @@ -246,7 +232,7 @@ def test_handle_sampling_error_false_temp_file(): 'Test error' ) with pytest.raises(ValueError, match=error_msg): - handle_sampling_error(False, 'test.csv', ValueError('Test error')) + handle_sampling_error('test.csv', ValueError('Test error')) def test_handle_sampling_error_false_temp_file_none_output_file(): @@ -258,7 +244,7 @@ def test_handle_sampling_error_false_temp_file_none_output_file(): # Run and Assert error_msg = 'Test error' with pytest.raises(ValueError, match=error_msg): - handle_sampling_error(False, 'test.csv', ValueError('Test error')) + handle_sampling_error('test.csv', ValueError('Test error')) def test_handle_sampling_error_ignore(): @@ -266,7 +252,7 @@ def test_handle_sampling_error_ignore(): # Run and assert error_msg = 'Unable to sample any rows for the given conditions.' with pytest.raises(ValueError, match=error_msg): - handle_sampling_error(True, 'test.csv', ValueError(error_msg)) + handle_sampling_error('test.csv', ValueError(error_msg)) def test_check_num_rows_reject_sampling_error(): @@ -342,3 +328,19 @@ def test_check_num_rows_valid(warning_mock): # Assert assert warning_mock.warn.call_count == 0 + + +@patch('builtins.open') +def test_validate_file_path(mock_open): + """Test the validate_file_path function.""" + # Setup + output_path = '.sample.csv.temp' + + # Run + result = validate_file_path(output_path) + none_result = validate_file_path(None) + + # Assert + assert output_path in result + assert none_result is None + mock_open.assert_called_once_with(result, 'w+') From 3bb795c22061b2717c79e55e73c8e1f82ffbd362 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 12 Jun 2024 14:21:47 -0500 Subject: [PATCH 2/3] Fix documentation --- sdv/lite/single_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdv/lite/single_table.py b/sdv/lite/single_table.py index 9e4f2129a..5c46ca183 100644 --- a/sdv/lite/single_table.py +++ b/sdv/lite/single_table.py @@ -136,7 +136,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, The batch size to use per attempt at sampling. Defaults to 10 times the number of rows. output_file_path (str or None): - The file to periodically write sampled rows to. + The file to periodically write sampled rows to. Defaults to None. Returns: pandas.DataFrame: @@ -167,7 +167,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, The batch size to use per attempt at sampling. Defaults to 10 times the number of rows. output_file_path (str or None): - The file to periodically write sampled rows to. + The file to periodically write sampled rows to. Defaults to None. Returns: pandas.DataFrame: From a113c2187b08bef48eb7485851df4e183236b44d Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 12 Jun 2024 14:44:54 -0500 Subject: [PATCH 3/3] Add docs --- sdv/single_table/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 154efbd93..f0d669d2a 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -929,7 +929,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar (tqdm.tqdm or None): The progress bar to update. output_file_path (str or None): - The file to periodically write sampled rows to. + The file to periodically write sampled rows to. Defaults to None. Returns: pandas.DataFrame: @@ -1052,7 +1052,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, batch_size (int): The batch size to use per sampling call. output_file_path (str or None): - The file to periodically write sampled rows to. + The file to periodically write sampled rows to. Defaults to None. Returns: pandas.DataFrame: @@ -1127,7 +1127,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, batch_size (int): The batch size to use per sampling call. output_file_path (str or None): - The file to periodically write sampled rows to. + The file to periodically write sampled rows to. Defaults to None. Returns: pandas.DataFrame: