diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index badc9a5a4..e70d4ee61 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -9,7 +9,7 @@ import pandas as pd import rdt from pandas.api.types import is_float_dtype, is_integer_dtype -from rdt.transformers import AnonymizedFaker, RegexGenerator, get_default_transformers +from rdt.transformers import AnonymizedFaker, IDGenerator, RegexGenerator, get_default_transformers from sdv.constraints import Constraint from sdv.constraints.base import get_subclasses @@ -458,13 +458,30 @@ def _create_config(self, data, columns_created_by_constraints): if sdtype == 'id': is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype) - transformers[column] = self.create_regex_generator( - column, - sdtype, - column_metadata, - is_numeric - ) - sdtypes[column] = 'text' + if column_metadata.get('regex_format', False): + transformers[column] = self.create_regex_generator( + column, + sdtype, + column_metadata, + is_numeric + ) + sdtypes[column] = 'text' + + elif column in self._keys: + prefix = None + if not is_numeric: + prefix = 'sdv-id-' + + transformers[column] = IDGenerator(prefix=prefix) + sdtypes[column] = 'text' + + else: + transformers[column] = AnonymizedFaker( + provider_name=None, + function_name='bothify', + function_kwargs={'text': '#####'} + ) + sdtypes[column] = 'pii' elif pii: enforce_uniqueness = bool(column in self._keys) diff --git a/setup.py b/setup.py index 5d22f4f60..f67f2cd6b 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ 'copulas>=0.9.0,<0.10', 'ctgan>=0.7.4,<0.8', 'deepecho>=0.4.2,<0.5', - 'rdt>=1.6.1,<2', + 'rdt>=1.7.0.dev0', 'sdmetrics>=0.11.0,<0.12', 'cloudpickle>=2.1.0,<3.0', 'boto3>=1.15.0,<2', diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index b8eb73d41..5fbc610c9 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -3,7 +3,7 @@ import numpy as np from rdt.transformers import ( - AnonymizedFaker, BinaryEncoder, FloatFormatter, LabelEncoder, RegexGenerator, + AnonymizedFaker, BinaryEncoder, FloatFormatter, RegexGenerator, UniformEncoder, UnixTimestampEncoder) from sdv.data_processing import DataProcessor @@ -247,23 +247,23 @@ def test_data_processor_prepare_for_fitting(): # Assert field_transformers = dp._hyper_transformer.field_transformers expected_transformers = { - 'mba_spec': LabelEncoder, + 'mba_spec': UniformEncoder, 'employability_perc': FloatFormatter, - 'placed': LabelEncoder, + 'placed': UniformEncoder, 'student_id': RegexGenerator, 'experience_years': FloatFormatter, - 'duration': LabelEncoder, + 'duration': UniformEncoder, 'salary': FloatFormatter, 'second_perc': FloatFormatter, 'start_date': UnixTimestampEncoder, 'address': AnonymizedFaker, - 'gender': LabelEncoder, + 'gender': UniformEncoder, 'mba_perc': FloatFormatter, - 'degree_type': LabelEncoder, + 'degree_type': UniformEncoder, 'end_date': UnixTimestampEncoder, - 'high_spec': LabelEncoder, + 'high_spec': UniformEncoder, 'high_perc': FloatFormatter, - 'work_experience': LabelEncoder, + 'work_experience': UniformEncoder, 'degree_perc': FloatFormatter } for column_name, transformer_class in expected_transformers.items(): diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 0176e477a..d7426a96a 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -3,7 +3,7 @@ import pandas as pd import pkg_resources import pytest -from rdt.transformers import AnonymizedFaker, FloatFormatter, LabelEncoder, RegexGenerator +from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder from sdv.metadata import SingleTableMetadata from sdv.sampling import Condition @@ -252,7 +252,7 @@ def test_transformers_correctly_auto_assigned(): assert isinstance(transformers['numerical_col'], FloatFormatter) assert isinstance(transformers['pii_col'], AnonymizedFaker) assert isinstance(transformers['primary_key'], RegexGenerator) - assert isinstance(transformers['categorical_col'], LabelEncoder) + assert isinstance(transformers['categorical_col'], UniformEncoder) assert transformers['numerical_col'].missing_value_replacement == 'mean' assert transformers['numerical_col'].missing_value_generation == 'random' @@ -265,8 +265,6 @@ def test_transformers_correctly_auto_assigned(): assert transformers['primary_key'].regex_format == 'user-[0-9]{3}' assert transformers['primary_key'].enforce_uniqueness is True - assert transformers['categorical_col'].add_noise is True - def test_modeling_with_complex_datetimes(): """Test that models work with datetimes passed as strings or ints with complex format.""" diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 4a1b51fd3..1ace72f63 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -252,13 +252,13 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc model.add_constraints([constraint]) sampled_numeric_data = [ pd.DataFrame({ - 'city#state': [0, 1, 2, 0, 0], + 'city#state': [0.1, 1, 0.75, 0.25, 0.25], 'age': [30, 30, 30, 30, 30] }), pd.DataFrame({ - 'city#state': [1], + 'city#state': [0.75], 'age': [30] - }) + }), ] gm_mock.return_value.sample.side_effect = sampled_numeric_data model.fit(data) diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 9f8994323..58ec22590 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -6,7 +6,8 @@ import pytest from rdt.errors import ConfigNotSetError from rdt.errors import NotFittedError as RDTNotFittedError -from rdt.transformers import FloatFormatter, LabelEncoder, UnixTimestampEncoder +from rdt.transformers import ( + AnonymizedFaker, FloatFormatter, IDGenerator, UniformEncoder, UnixTimestampEncoder) from sdv.constraints.errors import MissingConstraintColumnError from sdv.constraints.tabular import Positive, ScalarRange @@ -73,8 +74,8 @@ def test___init__( mock_default_transformers.return_value = { 'numerical': 'FloatFormatter()', - 'categorical': 'LabelEncoder(add_noise=True)', - 'boolean': 'LabelEncoder(add_noise=True)', + 'categorical': 'UniformEncoder()', + 'boolean': 'UniformEncoder()', 'datetime': 'UnixTimestampEncoder()', 'text': 'RegexGenerator()', 'pii': 'AnonymizedFaker()', @@ -114,8 +115,8 @@ def test___init__( expected_default_transformers = { 'numerical': 'FloatFormatter()', - 'categorical': 'LabelEncoder(add_noise=True)', - 'boolean': 'LabelEncoder(add_noise=True)', + 'categorical': 'UniformEncoder()', + 'boolean': 'UniformEncoder()', 'datetime': 'UnixTimestampEncoder()', 'id': 'RegexGenerator()', 'pii': 'AnonymizedFaker()', @@ -758,7 +759,7 @@ def test__update_transformers_by_sdtypes(self): # Setup instance = Mock() instance._transformers_by_sdtype = { - 'categorical': 'labelencoder', + 'categorical': 'UniformEncoder', 'numerical': 'float', 'boolean': None } @@ -924,7 +925,7 @@ def test__get_transformer_instance_no_kwargs(self): dp = Mock() dp._transformers_by_sdtype = { 'numerical': 'FloatFormatter', - 'categorical': 'LabelEncoder' + 'categorical': 'UniformEncoder' } # Run @@ -968,7 +969,7 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log): 'map_col#cat_col': ['z#a', 'x#b'], 'low#high': [0.2, 0.5] }) - dp._get_transformer_instance = Mock(return_value='LabelEncoder') + dp._get_transformer_instance = Mock(return_value='UniformEncoder') mock_rdt.transformers.FloatFormatter.return_value = 'FloatFormatter' config = { @@ -977,8 +978,8 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log): 'pii_col': 'FloatFormatter', 'low': 'FloatFormatter', 'high': 'FloatFormatter', - 'cat_col': 'LabelEncoder', - 'map_col': 'LabelEncoder', + 'cat_col': 'UniformEncoder', + 'map_col': 'UniformEncoder', }, 'sdtypes': { 'id_col': 'numerical', @@ -1001,8 +1002,8 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log): 'pii_col': 'FloatFormatter', 'low': 'FloatFormatter', 'high': 'FloatFormatter', - 'map_col': 'LabelEncoder', - 'map_col#cat_col': 'LabelEncoder', + 'map_col': 'UniformEncoder', + 'map_col#cat_col': 'UniformEncoder', 'low#high': 'FloatFormatter' }, 'sdtypes': { @@ -1052,17 +1053,23 @@ def test__create_config(self): 'email': ['a@aol.com', 'b@gmail.com', 'c@gmx.com'], 'first_name': ['John', 'Doe', 'Johanna'], 'id': ['ID_001', 'ID_002', 'ID_003'], + 'id_no_regex': ['ID_001', 'ID_002', 'ID_003'], + 'id_numeric': [0, 1, 2], + 'id_column': ['ID_999', 'ID_999', 'ID_007'], 'date': ['2021-02-01', '2022-03-05', '2023-01-31'] }) dp = DataProcessor(SingleTableMetadata(), locales=locales) dp.metadata = Mock() dp.create_anonymized_transformer = Mock() dp.create_regex_generator = Mock() + dp.create_id_generator = Mock() dp.create_anonymized_transformer.return_value = 'AnonymizedFaker' dp.create_regex_generator.return_value = 'RegexGenerator' + dp.create_id_generator.return_value = 'IDGenerator' dp.metadata.primary_key = 'id' + dp.metadata.alternate_keys = ['id_no_regex', 'id_numeric'] dp._primary_key = 'id' - dp._keys = ['id'] + dp._keys = ['id', 'id_no_regex', 'id_numeric'] dp.metadata.columns = { 'int': {'sdtype': 'numerical'}, 'float': {'sdtype': 'numerical'}, @@ -1071,6 +1078,9 @@ def test__create_config(self): 'email': {'sdtype': 'email', 'pii': True}, 'first_name': {'sdtype': 'first_name'}, 'id': {'sdtype': 'id', 'regex_format': 'ID_\\d{3}[0-9]'}, + 'id_no_regex': {'sdtype': 'id'}, + 'id_numeric': {'sdtype': 'id'}, + 'id_column': {'sdtype': 'id'}, 'date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} } @@ -1091,6 +1101,9 @@ def test__create_config(self): 'email': 'pii', 'first_name': 'pii', 'id': 'text', + 'id_no_regex': 'text', + 'id_numeric': 'text', + 'id_column': 'pii', 'date': 'datetime' } @@ -1104,10 +1117,10 @@ def test__create_config(self): assert float_transformer.missing_value_replacement == 'mean' assert float_transformer.missing_value_generation == 'random' - assert isinstance(config['transformers']['bool'], LabelEncoder) - assert isinstance(config['transformers']['created_bool'], LabelEncoder) - assert isinstance(config['transformers']['categorical'], LabelEncoder) - assert isinstance(config['transformers']['created_categorical'], LabelEncoder) + assert isinstance(config['transformers']['bool'], UniformEncoder) + assert isinstance(config['transformers']['created_bool'], UniformEncoder) + assert isinstance(config['transformers']['categorical'], UniformEncoder) + assert isinstance(config['transformers']['created_categorical'], UniformEncoder) assert isinstance(config['transformers']['int'], FloatFormatter) assert isinstance(config['transformers']['float'], FloatFormatter) @@ -1126,6 +1139,21 @@ def test__create_config(self): assert datetime_transformer.datetime_format == '%Y-%m-%d' assert dp._primary_key == 'id' + id_no_regex_transformer = config['transformers']['id_no_regex'] + assert isinstance(id_no_regex_transformer, IDGenerator) + assert id_no_regex_transformer.prefix == 'sdv-id-' + assert id_no_regex_transformer.starting_value == 0 + + id_numeric_transformer = config['transformers']['id_numeric'] + assert isinstance(id_numeric_transformer, IDGenerator) + assert id_numeric_transformer.prefix is None + assert id_numeric_transformer.starting_value == 0 + + id_column_transformer = config['transformers']['id_column'] + assert isinstance(id_column_transformer, AnonymizedFaker) + assert id_column_transformer.function_name == 'bothify' + assert id_column_transformer.function_kwargs == {'text': '#####'} + dp.create_anonymized_transformer.calls == [ call('email', {'sdtype': 'email', 'pii': True, 'locales': locales}), call('first_name', {'sdtype': 'first_name', 'locales': locales})