Skip to content

Commit

Permalink
Fix constraints failing with datetime and None values (#1481)
Browse files (browse the repository at this point in the history)
* Fix bug with datetime values and constraints
  • Loading branch information
pvk-developer committed Jul 5, 2023
1 parent 9d5e86c commit 750332c
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 11 deletions.
23 changes: 17 additions & 6 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,12 @@ def _transform(self, table_data):
Transformed data.
"""
column = table_data[self._column_name].to_numpy()
diff_column = abs(column - self._value)
if self._is_datetime:
column = cast_to_datetime64(column)
diff_column = abs(column - self._value)
diff_column = diff_column.astype(np.float64)
else:
diff_column = abs(column - self._value)

self._diff_column_name = create_unique_name(self._diff_column_name, table_data.columns)
table_data[self._diff_column_name] = np.log(diff_column + 1)
Expand All @@ -709,7 +712,7 @@ def _reverse_transform(self, table_data):
diff_column = diff_column.round()

if self._is_datetime:
diff_column = diff_column.astype('timedelta64[ns]')
diff_column = convert_to_timedelta(diff_column)

if self._operator in [np.greater, np.greater_equal]:
original_column = self._value + diff_column
Expand Down Expand Up @@ -914,9 +917,10 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
low = table_data[self.low_column_name]
middle = table_data[self.middle_column_name]
high = table_data[self.high_column_name]
# Using ``to_numpy`` since ``get_datetime_diff`` requires ``numpy.ndarray``
low = table_data[self.low_column_name].to_numpy()
middle = table_data[self.middle_column_name].to_numpy()
high = table_data[self.high_column_name].to_numpy()

if self._is_datetime:
low_diff_column = get_datetime_diff(middle, low, self._dtype)
Expand Down Expand Up @@ -1115,6 +1119,9 @@ def is_valid(self, table_data):
"""
data = table_data[self._column_name]

if self._is_datetime:
data = cast_to_datetime64(data)

satisfy_low_bound = np.logical_or(
self._operator(self._low_value, data),
np.isnan(self._low_value),
Expand Down Expand Up @@ -1144,7 +1151,11 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
data = logit(table_data[self._column_name], self._low_value, self._high_value)
data = table_data[self._column_name]
if self._is_datetime:
data = cast_to_datetime64(table_data[self._column_name])

data = logit(data, self._low_value, self._high_value)
table_data[self._transformed_column] = data
table_data = table_data.drop(self._column_name, axis=1)

Expand Down
3 changes: 2 additions & 1 deletion sdv/constraints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ def cast_to_datetime64(value):
if isinstance(value, str):
value = pd.to_datetime(value).to_datetime64()
elif isinstance(value, pd.Series):
value.apply(lambda x: pd.to_datetime(x).to_datetime64())
value = value.astype('datetime64[ns]')
elif isinstance(value, (np.ndarray, list)):
value = np.array([
pd.to_datetime(item).to_datetime64()
if not pd.isna(item)
else pd.NaT.to_datetime64()
for item in value
])

Expand Down
257 changes: 257 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,263 @@ def test_synthesizer_with_inequality_constraint():
)


def test_inequality_constraint_with_datetimes_and_nones():
    """Check ``Inequality`` on datetime columns that also hold ``None`` values.

    Two string-encoded datetime columns containing ``None`` entries are used to fit a
    ``GaussianCopulaSynthesizer`` carrying an ``Inequality`` constraint (``A`` as the low
    column, ``B`` as the high column). Ten rows are sampled and compared against a fixed
    expected frame.
    """
    # Setup
    low_values = [None, None, '2020-01-02', '2020-03-04'] * 2
    high_values = [None, '2021-03-04', '2021-12-31', None] * 2
    data = pd.DataFrame(data={'A': low_values, 'B': high_values})

    # A fresh column description is built for every column so nothing is shared.
    metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            column_name: {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}
            for column_name in ('A', 'B')
        }
    })
    metadata.validate()

    inequality_constraint = {
        'constraint_class': 'Inequality',
        'constraint_parameters': {
            'low_column_name': 'A',
            'high_column_name': 'B'
        }
    }
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_constraints([inequality_constraint])
    synthesizer.validate(data)

    # Run
    synthesizer.fit(data)
    sampled = synthesizer.sample(10)

    # Assert
    # ``dict(enumerate(...))`` yields the ``{row_position: value}`` mapping per column.
    expected_low = [
        '2020-01-02', '2019-10-30', np.nan, np.nan, '2020-01-02',
        np.nan, '2019-10-30', np.nan, '2020-01-02', np.nan,
    ]
    expected_high = [
        '2021-12-30', '2021-10-27', '2021-10-27', '2021-10-27', np.nan,
        '2021-10-27', '2021-10-27', '2021-12-30', np.nan, '2021-10-27',
    ]
    expected_sampled = pd.DataFrame({
        'A': dict(enumerate(expected_low)),
        'B': dict(enumerate(expected_high)),
    })
    pd.testing.assert_frame_equal(expected_sampled, sampled)


def test_scalar_inequality_constraint_with_datetimes_and_nones():
    """Check ``ScalarInequality`` on a datetime column that also holds ``None`` values.

    A ``GaussianCopulaSynthesizer`` is fitted with a ``ScalarInequality`` constraint
    requiring column ``A`` to be ``>=`` ``'2019-01-01'``. Five rows are sampled and
    compared against a fixed expected frame.
    """
    # Setup
    column_a = [None, None, '2020-01-02', '2020-03-04']
    column_b = [None, '2021-03-04', '2021-12-31', None]
    data = pd.DataFrame(data={'A': column_a, 'B': column_b})

    # A fresh column description is built for every column so nothing is shared.
    metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            column_name: {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}
            for column_name in ('A', 'B')
        }
    })
    metadata.validate()

    scalar_inequality_constraint = {
        'constraint_class': 'ScalarInequality',
        'constraint_parameters': {
            'column_name': 'A',
            'relation': '>=',
            'value': '2019-01-01'
        }
    }
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_constraints([scalar_inequality_constraint])
    synthesizer.validate(data)

    # Run
    synthesizer.fit(data)
    sampled = synthesizer.sample(5)

    # Assert
    # ``dict(enumerate(...))`` yields the ``{row_position: value}`` mapping per column.
    expected_a = [np.nan, '2020-01-19', np.nan, '2020-01-29', '2020-01-31']
    expected_b = ['2021-07-28', '2021-07-14', '2021-07-26', '2021-07-02', '2021-06-06']
    expected_sampled = pd.DataFrame({
        'A': dict(enumerate(expected_a)),
        'B': dict(enumerate(expected_b)),
    })
    pd.testing.assert_frame_equal(expected_sampled, sampled)


def test_scalar_range_constraint_with_datetimes_and_nones():
    """Check ``ScalarRange`` on a datetime column that also holds ``None`` values.

    A ``GaussianCopulaSynthesizer`` is fitted with a non-strict ``ScalarRange`` constraint
    bounding column ``A`` between ``'2019-10-30'`` and ``'2020-03-04'``. Ten rows are
    sampled and compared against a fixed expected frame.
    """
    # Setup
    column_a = [None, None, '2020-01-02', '2020-03-04']
    column_b = [None, '2021-03-04', '2021-12-31', None]
    data = pd.DataFrame(data={'A': column_a, 'B': column_b})

    # A fresh column description is built for every column so nothing is shared.
    metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            column_name: {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}
            for column_name in ('A', 'B')
        }
    })
    metadata.validate()

    scalar_range_constraint = {
        'constraint_class': 'ScalarRange',
        'constraint_parameters': {
            'column_name': 'A',
            'low_value': '2019-10-30',
            'high_value': '2020-03-04',
            'strict_boundaries': False
        }
    }
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_constraints([scalar_range_constraint])
    synthesizer.validate(data)

    # Run
    synthesizer.fit(data)
    sampled = synthesizer.sample(10)

    # Assert
    # ``dict(enumerate(...))`` yields the ``{row_position: value}`` mapping per column.
    expected_a = [
        '2020-03-03', np.nan, '2020-03-03', np.nan, np.nan,
        '2020-03-03', np.nan, np.nan, np.nan, '2020-02-27',
    ]
    expected_b = [
        np.nan, np.nan, np.nan, np.nan, np.nan,
        '2021-04-14', np.nan, '2021-05-21', np.nan, np.nan,
    ]
    expected_sampled = pd.DataFrame({
        'A': dict(enumerate(expected_a)),
        'B': dict(enumerate(expected_b)),
    })
    pd.testing.assert_frame_equal(expected_sampled, sampled)


def test_range_constraint_with_datetimes_and_nones():
    """Check ``Range`` on datetime columns that also hold ``None`` values.

    A ``GaussianCopulaSynthesizer`` is fitted with a non-strict ``Range`` constraint using
    ``A`` as the low column, ``B`` as the middle column and ``C`` as the high column. Ten
    rows are sampled and compared against a fixed expected frame.
    """
    # Setup
    column_a = [None, None, '2020-01-02', '2020-03-04']
    column_b = [None, '2021-03-04', '2021-12-31', None]
    column_c = [None, '2022-03-04', '2022-12-31', None]
    data = pd.DataFrame(data={'A': column_a, 'B': column_b, 'C': column_c})

    # A fresh column description is built for every column so nothing is shared.
    metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            column_name: {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}
            for column_name in ('A', 'B', 'C')
        }
    })
    metadata.validate()

    range_constraint = {
        'constraint_class': 'Range',
        'constraint_parameters': {
            'low_column_name': 'A',
            'middle_column_name': 'B',
            'high_column_name': 'C',
            'strict_boundaries': False
        }
    }
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_constraints([range_constraint])
    synthesizer.validate(data)

    # Run
    synthesizer.fit(data)
    sampled = synthesizer.sample(10)

    # Assert
    # ``dict(enumerate(...))`` yields the ``{row_position: value}`` mapping per column.
    expected_a = [
        '2020-01-02', '2020-01-02', np.nan, '2020-01-02', '2019-10-30',
        np.nan, '2020-01-02', '2019-10-30', '2019-10-30', np.nan,
    ]
    expected_b = [
        '2021-12-30', '2021-12-30', '2021-10-27', np.nan, '2021-10-27',
        '2021-10-27', np.nan, '2021-10-27', np.nan, '2021-10-27',
    ]
    expected_c = [
        '2022-12-30', '2022-12-30', '2022-10-27', np.nan, '2022-10-27',
        '2022-10-27', np.nan, '2022-10-27', np.nan, '2022-10-27',
    ]
    expected_sampled = pd.DataFrame({
        'A': dict(enumerate(expected_a)),
        'B': dict(enumerate(expected_b)),
        'C': dict(enumerate(expected_c)),
    })
    pd.testing.assert_frame_equal(expected_sampled, sampled)


def test_inequality_constraint_all_possible_nans_configurations():
"""Test that the inequality constraint works with all possible NaN configurations."""
# Setup
Expand Down
16 changes: 12 additions & 4 deletions tests/unit/constraints/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def test_cast_to_datetime64():
"""
# Setup
string_value = '2021-02-02'
list_value = [np.nan, '2021-02-02']
series_value = pd.Series(['2021-02-02'])
list_value = [None, np.nan, '2021-02-02']
series_value = pd.Series(['2021-02-02', None, pd.NaT])

# Run
string_out = cast_to_datetime64(string_value)
Expand All @@ -118,8 +118,16 @@ def test_cast_to_datetime64():

# Assert
expected_string_output = np.datetime64('2021-02-02')
expected_series_output = pd.Series(np.datetime64('2021-02-02'))
expected_list_output = np.array([np.datetime64('NaT'), '2021-02-02'], dtype='datetime64[ns]')
expected_series_output = pd.Series([
np.datetime64('2021-02-02'),
np.datetime64('NaT'),
np.datetime64('NaT')
])
expected_list_output = np.array([
np.datetime64('NaT'),
np.datetime64('NaT'),
'2021-02-02'
], dtype='datetime64[ns]')
np.testing.assert_array_equal(expected_list_output, list_out)
pd.testing.assert_series_equal(expected_series_output, series_out)
assert expected_string_output == string_out
Expand Down

0 comments on commit 750332c

Please sign in to comment.