diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 2e249b5d2..55476a741 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -681,9 +681,12 @@ def _transform(self, table_data): Transformed data. """ column = table_data[self._column_name].to_numpy() - diff_column = abs(column - self._value) if self._is_datetime: + column = cast_to_datetime64(column) + diff_column = abs(column - self._value) diff_column = diff_column.astype(np.float64) + else: + diff_column = abs(column - self._value) self._diff_column_name = create_unique_name(self._diff_column_name, table_data.columns) table_data[self._diff_column_name] = np.log(diff_column + 1) @@ -709,7 +712,7 @@ def _reverse_transform(self, table_data): diff_column = diff_column.round() if self._is_datetime: - diff_column = diff_column.astype('timedelta64[ns]') + diff_column = convert_to_timedelta(diff_column) if self._operator in [np.greater, np.greater_equal]: original_column = self._value + diff_column @@ -914,9 +917,10 @@ def _transform(self, table_data): pandas.DataFrame: Transformed data. """ - low = table_data[self.low_column_name] - middle = table_data[self.middle_column_name] - high = table_data[self.high_column_name] + # Using ``to_numpy`` since ``get_datetime_diff`` requires ``numpy.ndarray`` + low = table_data[self.low_column_name].to_numpy() + middle = table_data[self.middle_column_name].to_numpy() + high = table_data[self.high_column_name].to_numpy() if self._is_datetime: low_diff_column = get_datetime_diff(middle, low, self._dtype) @@ -1115,6 +1119,9 @@ def is_valid(self, table_data): """ data = table_data[self._column_name] + if self._is_datetime: + data = cast_to_datetime64(data) + satisfy_low_bound = np.logical_or( self._operator(self._low_value, data), np.isnan(self._low_value), @@ -1144,7 +1151,11 @@ def _transform(self, table_data): pandas.DataFrame: Transformed data. """ - data = logit(table_data[self._column_name], self._low_value, self._high_value) + data = table_data[self._column_name] + if self._is_datetime: + data = cast_to_datetime64(table_data[self._column_name]) + + data = logit(data, self._low_value, self._high_value) table_data[self._transformed_column] = data table_data = table_data.drop(self._column_name, axis=1) diff --git a/sdv/constraints/utils.py b/sdv/constraints/utils.py index c35edae98..cd083c1d1 100644 --- a/sdv/constraints/utils.py +++ b/sdv/constraints/utils.py @@ -20,11 +20,12 @@ def cast_to_datetime64(value): if isinstance(value, str): value = pd.to_datetime(value).to_datetime64() elif isinstance(value, pd.Series): - value.apply(lambda x: pd.to_datetime(x).to_datetime64()) value = value.astype('datetime64[ns]') elif isinstance(value, (np.ndarray, list)): value = np.array([ pd.to_datetime(item).to_datetime64() + if not pd.isna(item) + else pd.NaT.to_datetime64() for item in value ]) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index c4a8fedc9..8e1b3706b 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -810,6 +810,263 @@ def test_synthesizer_with_inequality_constraint(): ) +def test_inequality_constraint_with_datetimes_and_nones(): + """Test that the ``Inequality`` constraint works with ``None`` and ``datetime``.""" + # Setup + data = pd.DataFrame(data={ + 'A': [None, None, '2020-01-02', '2020-03-04'] * 2, + 'B': [None, '2021-03-04', '2021-12-31', None] * 2 + }) + + metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + } + }) + + metadata.validate() + synth = GaussianCopulaSynthesizer(metadata) + synth.add_constraints([ + { + 'constraint_class': 'Inequality', + 'constraint_parameters': { + 'low_column_name': 'A', + 'high_column_name': 'B' + } + } + ]) + synth.validate(data) + + # Run + synth.fit(data) + sampled = synth.sample(10) + + # Assert + expected_sampled = pd.DataFrame({ + 'A': { + 0: '2020-01-02', + 1: '2019-10-30', + 2: np.nan, + 3: np.nan, + 4: '2020-01-02', + 5: np.nan, + 6: '2019-10-30', + 7: np.nan, + 8: '2020-01-02', + 9: np.nan + }, + 'B': { + 0: '2021-12-30', + 1: '2021-10-27', + 2: '2021-10-27', + 3: '2021-10-27', + 4: np.nan, + 5: '2021-10-27', + 6: '2021-10-27', + 7: '2021-12-30', + 8: np.nan, + 9: '2021-10-27' + } + }) + pd.testing.assert_frame_equal(expected_sampled, sampled) + + +def test_scalar_inequality_constraint_with_datetimes_and_nones(): + """Test that the ``ScalarInequality`` constraint works with ``None`` and ``datetime``.""" + # Setup + data = pd.DataFrame(data={ + 'A': [None, None, '2020-01-02', '2020-03-04'], + 'B': [None, '2021-03-04', '2021-12-31', None] + }) + + metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + } + }) + + metadata.validate() + synth = GaussianCopulaSynthesizer(metadata) + synth.add_constraints([ + { + 'constraint_class': 'ScalarInequality', + 'constraint_parameters': { + 'column_name': 'A', + 'relation': '>=', + 'value': '2019-01-01' + } + } + ]) + synth.validate(data) + + # Run + synth.fit(data) + sampled = synth.sample(5) + + # Assert + expected_sampled = pd.DataFrame({ + 'A': { + 0: np.nan, + 1: '2020-01-19', + 2: np.nan, + 3: '2020-01-29', + 4: '2020-01-31', + }, + 'B': { + 0: '2021-07-28', + 1: '2021-07-14', + 2: '2021-07-26', + 3: '2021-07-02', + 4: '2021-06-06', + } + }) + pd.testing.assert_frame_equal(expected_sampled, sampled) + + +def test_scalar_range_constraint_with_datetimes_and_nones(): + """Test that the ``ScalarRange`` constraint works with ``None`` and ``datetime``.""" + # Setup + data = pd.DataFrame(data={ + 'A': [None, None, '2020-01-02', '2020-03-04'], + 'B': [None, '2021-03-04', '2021-12-31', None] + }) + + metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + } + }) + + metadata.validate() + synth = GaussianCopulaSynthesizer(metadata) + synth.add_constraints([ + { + 'constraint_class': 'ScalarRange', + 'constraint_parameters': { + 'column_name': 'A', + 'low_value': '2019-10-30', + 'high_value': '2020-03-04', + 'strict_boundaries': False + } + } + ]) + synth.validate(data) + + # Run + synth.fit(data) + sampled = synth.sample(10) + + # Assert + expected_sampled = pd.DataFrame({ + 'A': { + 0: '2020-03-03', + 1: np.nan, + 2: '2020-03-03', + 3: np.nan, + 4: np.nan, + 5: '2020-03-03', + 6: np.nan, + 7: np.nan, + 8: np.nan, + 9: '2020-02-27', + }, + 'B': { + 0: np.nan, + 1: np.nan, + 2: np.nan, + 3: np.nan, + 4: np.nan, + 5: '2021-04-14', + 6: np.nan, + 7: '2021-05-21', + 8: np.nan, + 9: np.nan, + } + }) + pd.testing.assert_frame_equal(expected_sampled, sampled) + + +def test_range_constraint_with_datetimes_and_nones(): + """Test that the ``Range`` constraint works with ``None`` and ``datetime``.""" + # Setup + data = pd.DataFrame(data={ + 'A': [None, None, '2020-01-02', '2020-03-04'], + 'B': [None, '2021-03-04', '2021-12-31', None], + 'C': [None, '2022-03-04', '2022-12-31', None], + }) + + metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'C': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + } + }) + + metadata.validate() + synth = GaussianCopulaSynthesizer(metadata) + synth.add_constraints([ + { + 'constraint_class': 'Range', + 'constraint_parameters': { + 'low_column_name': 'A', + 'middle_column_name': 'B', + 'high_column_name': 'C', + 'strict_boundaries': False + } + } + ]) + synth.validate(data) + + # Run + synth.fit(data) + sampled = synth.sample(10) + + # Assert + expected_sampled = pd.DataFrame({ + 'A': { + 0: '2020-01-02', + 1: '2020-01-02', + 2: np.nan, + 3: '2020-01-02', + 4: '2019-10-30', + 5: np.nan, + 6: '2020-01-02', + 7: '2019-10-30', + 8: '2019-10-30', + 9: np.nan + }, + 'B': { + 0: '2021-12-30', + 1: '2021-12-30', + 2: '2021-10-27', + 3: np.nan, + 4: '2021-10-27', + 5: '2021-10-27', + 6: np.nan, + 7: '2021-10-27', + 8: np.nan, + 9: '2021-10-27' + }, + 'C': { + 0: '2022-12-30', + 1: '2022-12-30', + 2: '2022-10-27', + 3: np.nan, + 4: '2022-10-27', + 5: '2022-10-27', + 6: np.nan, + 7: '2022-10-27', + 8: np.nan, + 9: '2022-10-27' + } + }) + pd.testing.assert_frame_equal(expected_sampled, sampled) + + def test_inequality_constraint_all_possible_nans_configurations(): """Test that the inequality constraint works with all possible NaN configurations.""" # Setup diff --git a/tests/unit/constraints/test_utils.py b/tests/unit/constraints/test_utils.py index e564dc859..8d6c2b72d 100644 --- a/tests/unit/constraints/test_utils.py +++ b/tests/unit/constraints/test_utils.py @@ -108,8 +108,8 @@ def test_cast_to_datetime64(): """ # Setup string_value = '2021-02-02' - list_value = [np.nan, '2021-02-02'] - series_value = pd.Series(['2021-02-02']) + list_value = [None, np.nan, '2021-02-02'] + series_value = pd.Series(['2021-02-02', None, pd.NaT]) # Run string_out = cast_to_datetime64(string_value) @@ -118,8 +118,16 @@ def test_cast_to_datetime64(): # Assert expected_string_output = np.datetime64('2021-02-02') - expected_series_output = pd.Series(np.datetime64('2021-02-02')) - expected_list_output = np.array([np.datetime64('NaT'), '2021-02-02'], dtype='datetime64[ns]') + expected_series_output = pd.Series([ + np.datetime64('2021-02-02'), + np.datetime64('NaT'), + np.datetime64('NaT') + ]) + expected_list_output = np.array([ + np.datetime64('NaT'), + np.datetime64('NaT'), + '2021-02-02' + ], dtype='datetime64[ns]') np.testing.assert_array_equal(expected_list_output, list_out) pd.testing.assert_series_equal(expected_series_output, series_out) assert expected_string_output == string_out