diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py
index 63fdcaa9a..d78c7ff10 100644
--- a/sdv/constraints/tabular.py
+++ b/sdv/constraints/tabular.py
@@ -48,6 +48,7 @@
     cast_to_datetime64,
     compute_nans_column,
     get_datetime_diff,
+    get_mappable_combination,
     logit,
     matches_datetime_format,
     revert_nans_columns,
@@ -297,9 +298,10 @@ def _fit(self, table_data):
         self._combinations_to_uuids = {}
         self._uuids_to_combinations = {}
         for combination in self._combinations.itertuples(index=False, name=None):
+            mappable_combination = get_mappable_combination(combination)
             uuid_str = str(uuid.uuid4())
-            self._combinations_to_uuids[combination] = uuid_str
-            self._uuids_to_combinations[uuid_str] = combination
+            self._combinations_to_uuids[mappable_combination] = uuid_str
+            self._uuids_to_combinations[uuid_str] = mappable_combination
 
     def is_valid(self, table_data):
         """Say whether the column values are within the original combinations.
@@ -333,6 +335,7 @@ def _transform(self, table_data):
             pandas.DataFrame:
                 Transformed data.
         """
+        table_data[self._columns] = table_data[self._columns].replace({np.nan: None})
         combinations = table_data[self._columns].itertuples(index=False, name=None)
         uuids = map(self._combinations_to_uuids.get, combinations)
         table_data[self._joint_column] = list(uuids)
diff --git a/sdv/constraints/utils.py b/sdv/constraints/utils.py
index 714a2c489..c395d3a29 100644
--- a/sdv/constraints/utils.py
+++ b/sdv/constraints/utils.py
@@ -204,3 +204,20 @@ def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=
     diff_column = diff_column.astype(np.float64)
     diff_column[nan_mask] = np.nan
     return diff_column
+
+
+def get_mappable_combination(combination):
+    """Get a mappable combination of values.
+
+    This function replaces NaN values with None inside the tuple
+    to ensure consistent comparisons when using mapping.
+
+    Args:
+        combination (tuple):
+            A combination of values.
+
+    Returns:
+        tuple:
+            A mappable combination of values.
+    """
+    return tuple(None if pd.isna(x) else x for x in combination)
diff --git a/tests/integration/constraints/test_tabular.py b/tests/integration/constraints/test_tabular.py
new file mode 100644
index 000000000..23fa701f0
--- /dev/null
+++ b/tests/integration/constraints/test_tabular.py
@@ -0,0 +1,72 @@
+import numpy as np
+import pandas as pd
+
+from sdv.metadata import SingleTableMetadata
+from sdv.single_table import GaussianCopulaSynthesizer
+
+
+def test_fixed_combinations_integers():
+    """Test that FixedCombinations constraint works with integer columns."""
+    data = pd.DataFrame({
+        'A': [1, 2, 3, 1, 2, 1],
+        'B': [10, 20, 30, 10, 20, 10],
+    })
+    metadata = SingleTableMetadata().load_from_dict({
+        'columns': {
+            'A': {'sdtype': 'categorical'},
+            'B': {'sdtype': 'categorical'},
+        }
+    })
+
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    my_constraint = {
+        'constraint_class': 'FixedCombinations',
+        'constraint_parameters': {'column_names': ['A', 'B']},
+    }
+    synthesizer.add_constraints(constraints=[my_constraint])
+
+    # Run
+    synthesizer.fit(data)
+    synthetic_data = synthesizer.sample(1000)
+
+    # Assert
+    assert len(synthetic_data) == 1000
+    pd.testing.assert_frame_equal(
+        synthetic_data.drop_duplicates(ignore_index=True),
+        data.drop_duplicates(ignore_index=True),
+        check_like=True,
+    )
+
+
+def test_fixed_combinations_with_nans():
+    """Test that FixedCombinations constraint works with NaNs."""
+    # Setup
+    data = pd.DataFrame({
+        'A': [1, 2, np.nan, 1, 2, 1],
+        'B': [10, 20, 30, 10, 20, 10],
+    })
+    metadata = SingleTableMetadata().load_from_dict({
+        'columns': {
+            'A': {'sdtype': 'categorical'},
+            'B': {'sdtype': 'categorical'},
+        }
+    })
+
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    my_constraint = {
+        'constraint_class': 'FixedCombinations',
+        'constraint_parameters': {'column_names': ['A', 'B']},
+    }
+    synthesizer.add_constraints(constraints=[my_constraint])
+
+    # Run
+    synthesizer.fit(data)
+    synthetic_data = synthesizer.sample(1000)
+
+    # Assert
+    assert len(synthetic_data) == 1000
+    pd.testing.assert_frame_equal(
+        synthetic_data.drop_duplicates(ignore_index=True),
+        data.drop_duplicates(ignore_index=True),
+        check_like=True,
+    )
diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py
index 2d4b0ce04..306cb581e 100644
--- a/tests/unit/constraints/test_tabular.py
+++ b/tests/unit/constraints/test_tabular.py
@@ -3,7 +3,7 @@
 import operator
 import re
 from datetime import datetime
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, call, patch
 
 import numpy as np
 import pandas as pd
@@ -658,7 +658,8 @@ def test___init__with_one_column(self):
         with pytest.raises(ValueError, match=err_msg):
             FixedCombinations(column_names=columns)
 
-    def test__fit(self):
+    @patch('sdv.constraints.tabular.get_mappable_combination')
+    def test__fit(self, get_mappable_combination_mock):
         """Test the ``FixedCombinations._fit`` method.
 
         The ``FixedCombinations.fit`` method is expected to:
@@ -683,9 +684,14 @@ def test__fit(self):
 
         # Asserts
         expected_combinations = pd.DataFrame({'b': ['d', 'e', 'f'], 'c': ['g', 'h', 'i']})
+        expected_calls = [
+            call(combination)
+            for combination in instance._combinations.itertuples(index=False, name=None)
+        ]
         assert instance._separator == '##'
         assert instance._joint_column == 'b##c'
         pd.testing.assert_frame_equal(instance._combinations, expected_combinations)
+        assert get_mappable_combination_mock.call_args_list == expected_calls
 
     def test_is_valid_true(self):
         """Test the ``FixedCombinations.is_valid`` method.
diff --git a/tests/unit/constraints/test_utils.py b/tests/unit/constraints/test_utils.py
index 514ffa13d..5775db150 100644
--- a/tests/unit/constraints/test_utils.py
+++ b/tests/unit/constraints/test_utils.py
@@ -10,6 +10,7 @@
     cast_to_datetime64,
     compute_nans_column,
     get_datetime_diff,
+    get_mappable_combination,
     get_nan_component_value,
     logit,
     matches_datetime_format,
@@ -302,3 +303,19 @@ def test_get_datetime_diff():
 
     # Assert
     assert np.array_equal(expected, diff, equal_nan=True)
+
+
+def test_get_mappable_combination():
+    """Test the ``get_mappable_combination`` method."""
+    # Setup
+    already_mappable = ('a', 1, 1.2, 'b')
+    not_mappable = ('a', 1, np.nan, 'b')
+
+    # Run
+    result_already_mappable = get_mappable_combination(already_mappable)
+    result_not_mappable = get_mappable_combination(not_mappable)
+
+    # Assert
+    expected_result_not_mappable = ('a', 1, None, 'b')
+    assert result_already_mappable == already_mappable
+    assert result_not_mappable == expected_result_not_mappable