Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using FixedCombinations constraint with an integer constraint column causes sampling to fail #2185

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
cast_to_datetime64,
compute_nans_column,
get_datetime_diff,
get_mappable_combination,
logit,
matches_datetime_format,
revert_nans_columns,
Expand Down Expand Up @@ -297,9 +298,10 @@ def _fit(self, table_data):
self._combinations_to_uuids = {}
self._uuids_to_combinations = {}
for combination in self._combinations.itertuples(index=False, name=None):
mappable_combination = get_mappable_combination(combination)
uuid_str = str(uuid.uuid4())
self._combinations_to_uuids[combination] = uuid_str
self._uuids_to_combinations[uuid_str] = combination
self._combinations_to_uuids[mappable_combination] = uuid_str
self._uuids_to_combinations[uuid_str] = mappable_combination

def is_valid(self, table_data):
"""Say whether the column values are within the original combinations.
Expand Down Expand Up @@ -333,6 +335,7 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data[self._columns] = table_data[self._columns].replace({np.nan: None})
combinations = table_data[self._columns].itertuples(index=False, name=None)
uuids = map(self._combinations_to_uuids.get, combinations)
table_data[self._joint_column] = list(uuids)
Expand Down
13 changes: 13 additions & 0 deletions sdv/constraints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,16 @@ def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=
diff_column = diff_column.astype(np.float64)
diff_column[nan_mask] = np.nan
return diff_column


def get_mappable_combination(combination):
    """Get a mappable combination of values.

    This function replaces missing values (anything ``pd.isna`` reports as
    null, e.g. ``NaN``) with ``None`` inside the tuple to ensure consistent
    comparisons when using mapping. ``NaN`` is not equal to itself, so a
    tuple holding it is unreliable as a dictionary key.

    Args:
        combination (tuple):
            A combination of values.

    Returns:
        tuple:
            A tuple of the same length and order as ``combination`` in which
            every missing scalar has been replaced with ``None``. All other
            values, including list-like ones, are returned untouched.
    """
    # ``pd.isna`` returns an array for list-like inputs, whose truth value is
    # ambiguous; only scalars can be "missing", so check for them explicitly.
    is_scalar = pd.api.types.is_scalar
    return tuple(
        None if is_scalar(value) and pd.isna(value) else value
        for value in combination
    )
39 changes: 39 additions & 0 deletions tests/integration/constraints/test_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import pandas as pd

from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer


def test_fixed_combinations_with_nans():
    """Test that FixedCombinations constraint works with NaNs."""
    # NOTE(review): per the PR author's review comment, this test crashes on
    # the main branch with the error described in the linked issue.

    # Setup
    column_a = [1, 2, np.nan, 1, 2, 1]
    column_b = [10, 20, 30, 10, 20, 10]
    data = pd.DataFrame({'A': column_a, 'B': column_b})

    metadata_dict = {
        'columns': {
            'A': {'sdtype': 'categorical'},
            'B': {'sdtype': 'categorical'},
        }
    }
    metadata = SingleTableMetadata().load_from_dict(metadata_dict)

    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_constraints(
        constraints=[
            {
                'constraint_class': 'FixedCombinations',
                'constraint_parameters': {'column_names': ['A', 'B']},
            }
        ]
    )

    # Run
    synthesizer.fit(data)
    synthetic_data = synthesizer.sample(1000)

    # Assert
    assert len(synthetic_data) == 1000

    # Only the combinations seen in the real data may appear in the sample.
    sampled_combinations = synthetic_data.drop_duplicates(ignore_index=True)
    real_combinations = data.drop_duplicates(ignore_index=True)
    pd.testing.assert_frame_equal(
        sampled_combinations,
        real_combinations,
        check_like=True,
    )
10 changes: 8 additions & 2 deletions tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
import re
from datetime import datetime
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, call, patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -658,7 +658,8 @@ def test___init__with_one_column(self):
with pytest.raises(ValueError, match=err_msg):
FixedCombinations(column_names=columns)

def test__fit(self):
@patch('sdv.constraints.tabular.get_mappable_combination')
def test__fit(self, get_mappable_combination_mock):
"""Test the ``FixedCombinations._fit`` method.

The ``FixedCombinations.fit`` method is expected to:
Expand All @@ -683,9 +684,14 @@ def test__fit(self):

# Asserts
expected_combinations = pd.DataFrame({'b': ['d', 'e', 'f'], 'c': ['g', 'h', 'i']})
expected_calls = [
call(combination)
for combination in instance._combinations.itertuples(index=False, name=None)
]
assert instance._separator == '##'
assert instance._joint_column == 'b##c'
pd.testing.assert_frame_equal(instance._combinations, expected_combinations)
assert get_mappable_combination_mock.call_args_list == expected_calls

def test_is_valid_true(self):
"""Test the ``FixedCombinations.is_valid`` method.
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/constraints/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
cast_to_datetime64,
compute_nans_column,
get_datetime_diff,
get_mappable_combination,
get_nan_component_value,
logit,
matches_datetime_format,
Expand Down Expand Up @@ -302,3 +303,19 @@ def test_get_datetime_diff():

# Assert
assert np.array_equal(expected, diff, equal_nan=True)


def test_get_mappable_combination():
    """Test that ``get_mappable_combination`` swaps NaN for ``None`` only."""
    # Setup
    without_nan = ('a', 1, 1.2, 'b')
    with_nan = ('a', 1, np.nan, 'b')

    # Run
    mapped_without_nan, mapped_with_nan = (
        get_mappable_combination(combo) for combo in (without_nan, with_nan)
    )

    # Assert
    # A tuple with no missing values must come back unchanged.
    assert mapped_without_nan == without_nan
    assert mapped_with_nan == ('a', 1, None, 'b')
Loading