Skip to content

Commit

Permalink
Use IDGenerator for ID columns (#1538)
Browse files Browse the repository at this point in the history
* Use IDGenerator for key columns without specified regex patterns

* lint

* temporarily point to RDT master

* point to RDT release candidate

* use UniformEncoder instead of LabelEncoder

* comment

* add newlines
  • Loading branch information
frances-h committed Aug 17, 2023
1 parent ba6a908 commit 473ecee
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 41 deletions.
33 changes: 25 additions & 8 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import rdt
from pandas.api.types import is_float_dtype, is_integer_dtype
from rdt.transformers import AnonymizedFaker, RegexGenerator, get_default_transformers
from rdt.transformers import AnonymizedFaker, IDGenerator, RegexGenerator, get_default_transformers

from sdv.constraints import Constraint
from sdv.constraints.base import get_subclasses
Expand Down Expand Up @@ -458,13 +458,30 @@ def _create_config(self, data, columns_created_by_constraints):

if sdtype == 'id':
is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype)
transformers[column] = self.create_regex_generator(
column,
sdtype,
column_metadata,
is_numeric
)
sdtypes[column] = 'text'
if column_metadata.get('regex_format', False):
transformers[column] = self.create_regex_generator(
column,
sdtype,
column_metadata,
is_numeric
)
sdtypes[column] = 'text'

elif column in self._keys:
prefix = None
if not is_numeric:
prefix = 'sdv-id-'

transformers[column] = IDGenerator(prefix=prefix)
sdtypes[column] = 'text'

else:
transformers[column] = AnonymizedFaker(
provider_name=None,
function_name='bothify',
function_kwargs={'text': '#####'}
)
sdtypes[column] = 'pii'

elif pii:
enforce_uniqueness = bool(column in self._keys)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.4,<0.8',
'deepecho>=0.4.2,<0.5',
'rdt>=1.6.1,<2',
'rdt>=1.7.0.dev0',
'sdmetrics>=0.11.0,<0.12',
'cloudpickle>=2.1.0,<3.0',
'boto3>=1.15.0,<2',
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, LabelEncoder, RegexGenerator,
AnonymizedFaker, BinaryEncoder, FloatFormatter, RegexGenerator, UniformEncoder,
UnixTimestampEncoder)

from sdv.data_processing import DataProcessor
Expand Down Expand Up @@ -247,23 +247,23 @@ def test_data_processor_prepare_for_fitting():
# Assert
field_transformers = dp._hyper_transformer.field_transformers
expected_transformers = {
'mba_spec': LabelEncoder,
'mba_spec': UniformEncoder,
'employability_perc': FloatFormatter,
'placed': LabelEncoder,
'placed': UniformEncoder,
'student_id': RegexGenerator,
'experience_years': FloatFormatter,
'duration': LabelEncoder,
'duration': UniformEncoder,
'salary': FloatFormatter,
'second_perc': FloatFormatter,
'start_date': UnixTimestampEncoder,
'address': AnonymizedFaker,
'gender': LabelEncoder,
'gender': UniformEncoder,
'mba_perc': FloatFormatter,
'degree_type': LabelEncoder,
'degree_type': UniformEncoder,
'end_date': UnixTimestampEncoder,
'high_spec': LabelEncoder,
'high_spec': UniformEncoder,
'high_perc': FloatFormatter,
'work_experience': LabelEncoder,
'work_experience': UniformEncoder,
'degree_perc': FloatFormatter
}
for column_name, transformer_class in expected_transformers.items():
Expand Down
6 changes: 2 additions & 4 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import pkg_resources
import pytest
from rdt.transformers import AnonymizedFaker, FloatFormatter, LabelEncoder, RegexGenerator
from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder

from sdv.metadata import SingleTableMetadata
from sdv.sampling import Condition
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_transformers_correctly_auto_assigned():
assert isinstance(transformers['numerical_col'], FloatFormatter)
assert isinstance(transformers['pii_col'], AnonymizedFaker)
assert isinstance(transformers['primary_key'], RegexGenerator)
assert isinstance(transformers['categorical_col'], LabelEncoder)
assert isinstance(transformers['categorical_col'], UniformEncoder)

assert transformers['numerical_col'].missing_value_replacement == 'mean'
assert transformers['numerical_col'].missing_value_generation == 'random'
Expand All @@ -265,8 +265,6 @@ def test_transformers_correctly_auto_assigned():
assert transformers['primary_key'].regex_format == 'user-[0-9]{3}'
assert transformers['primary_key'].enforce_uniqueness is True

assert transformers['categorical_col'].add_noise is True


def test_modeling_with_complex_datetimes():
"""Test that models work with datetimes passed as strings or ints with complex format."""
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/single_table/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,13 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc
model.add_constraints([constraint])
sampled_numeric_data = [
pd.DataFrame({
'city#state': [0, 1, 2, 0, 0],
'city#state': [0.1, 1, 0.75, 0.25, 0.25],
'age': [30, 30, 30, 30, 30]
}),
pd.DataFrame({
'city#state': [1],
'city#state': [0.75],
'age': [30]
})
}),
]
gm_mock.return_value.sample.side_effect = sampled_numeric_data
model.fit(data)
Expand Down
62 changes: 45 additions & 17 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest
from rdt.errors import ConfigNotSetError
from rdt.errors import NotFittedError as RDTNotFittedError
from rdt.transformers import FloatFormatter, LabelEncoder, UnixTimestampEncoder
from rdt.transformers import (
AnonymizedFaker, FloatFormatter, IDGenerator, UniformEncoder, UnixTimestampEncoder)

from sdv.constraints.errors import MissingConstraintColumnError
from sdv.constraints.tabular import Positive, ScalarRange
Expand Down Expand Up @@ -73,8 +74,8 @@ def test___init__(

mock_default_transformers.return_value = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'categorical': 'UniformEncoder()',
'boolean': 'UniformEncoder()',
'datetime': 'UnixTimestampEncoder()',
'text': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
Expand Down Expand Up @@ -114,8 +115,8 @@ def test___init__(

expected_default_transformers = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'categorical': 'UniformEncoder()',
'boolean': 'UniformEncoder()',
'datetime': 'UnixTimestampEncoder()',
'id': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
Expand Down Expand Up @@ -758,7 +759,7 @@ def test__update_transformers_by_sdtypes(self):
# Setup
instance = Mock()
instance._transformers_by_sdtype = {
'categorical': 'labelencoder',
'categorical': 'UniformEncoder',
'numerical': 'float',
'boolean': None
}
Expand Down Expand Up @@ -924,7 +925,7 @@ def test__get_transformer_instance_no_kwargs(self):
dp = Mock()
dp._transformers_by_sdtype = {
'numerical': 'FloatFormatter',
'categorical': 'LabelEncoder'
'categorical': 'UniformEncoder'
}

# Run
Expand Down Expand Up @@ -968,7 +969,7 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log):
'map_col#cat_col': ['z#a', 'x#b'],
'low#high': [0.2, 0.5]
})
dp._get_transformer_instance = Mock(return_value='LabelEncoder')
dp._get_transformer_instance = Mock(return_value='UniformEncoder')
mock_rdt.transformers.FloatFormatter.return_value = 'FloatFormatter'

config = {
Expand All @@ -977,8 +978,8 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log):
'pii_col': 'FloatFormatter',
'low': 'FloatFormatter',
'high': 'FloatFormatter',
'cat_col': 'LabelEncoder',
'map_col': 'LabelEncoder',
'cat_col': 'UniformEncoder',
'map_col': 'UniformEncoder',
},
'sdtypes': {
'id_col': 'numerical',
Expand All @@ -1001,8 +1002,8 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log):
'pii_col': 'FloatFormatter',
'low': 'FloatFormatter',
'high': 'FloatFormatter',
'map_col': 'LabelEncoder',
'map_col#cat_col': 'LabelEncoder',
'map_col': 'UniformEncoder',
'map_col#cat_col': 'UniformEncoder',
'low#high': 'FloatFormatter'
},
'sdtypes': {
Expand Down Expand Up @@ -1052,17 +1053,23 @@ def test__create_config(self):
'email': ['[email protected]', '[email protected]', '[email protected]'],
'first_name': ['John', 'Doe', 'Johanna'],
'id': ['ID_001', 'ID_002', 'ID_003'],
'id_no_regex': ['ID_001', 'ID_002', 'ID_003'],
'id_numeric': [0, 1, 2],
'id_column': ['ID_999', 'ID_999', 'ID_007'],
'date': ['2021-02-01', '2022-03-05', '2023-01-31']
})
dp = DataProcessor(SingleTableMetadata(), locales=locales)
dp.metadata = Mock()
dp.create_anonymized_transformer = Mock()
dp.create_regex_generator = Mock()
dp.create_id_generator = Mock()
dp.create_anonymized_transformer.return_value = 'AnonymizedFaker'
dp.create_regex_generator.return_value = 'RegexGenerator'
dp.create_id_generator.return_value = 'IDGenerator'
dp.metadata.primary_key = 'id'
dp.metadata.alternate_keys = ['id_no_regex', 'id_numeric']
dp._primary_key = 'id'
dp._keys = ['id']
dp._keys = ['id', 'id_no_regex', 'id_numeric']
dp.metadata.columns = {
'int': {'sdtype': 'numerical'},
'float': {'sdtype': 'numerical'},
Expand All @@ -1071,6 +1078,9 @@ def test__create_config(self):
'email': {'sdtype': 'email', 'pii': True},
'first_name': {'sdtype': 'first_name'},
'id': {'sdtype': 'id', 'regex_format': 'ID_\\d{3}[0-9]'},
'id_no_regex': {'sdtype': 'id'},
'id_numeric': {'sdtype': 'id'},
'id_column': {'sdtype': 'id'},
'date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}
}

Expand All @@ -1091,6 +1101,9 @@ def test__create_config(self):
'email': 'pii',
'first_name': 'pii',
'id': 'text',
'id_no_regex': 'text',
'id_numeric': 'text',
'id_column': 'pii',
'date': 'datetime'
}

Expand All @@ -1104,10 +1117,10 @@ def test__create_config(self):
assert float_transformer.missing_value_replacement == 'mean'
assert float_transformer.missing_value_generation == 'random'

assert isinstance(config['transformers']['bool'], LabelEncoder)
assert isinstance(config['transformers']['created_bool'], LabelEncoder)
assert isinstance(config['transformers']['categorical'], LabelEncoder)
assert isinstance(config['transformers']['created_categorical'], LabelEncoder)
assert isinstance(config['transformers']['bool'], UniformEncoder)
assert isinstance(config['transformers']['created_bool'], UniformEncoder)
assert isinstance(config['transformers']['categorical'], UniformEncoder)
assert isinstance(config['transformers']['created_categorical'], UniformEncoder)

assert isinstance(config['transformers']['int'], FloatFormatter)
assert isinstance(config['transformers']['float'], FloatFormatter)
Expand All @@ -1126,6 +1139,21 @@ def test__create_config(self):
assert datetime_transformer.datetime_format == '%Y-%m-%d'
assert dp._primary_key == 'id'

id_no_regex_transformer = config['transformers']['id_no_regex']
assert isinstance(id_no_regex_transformer, IDGenerator)
assert id_no_regex_transformer.prefix == 'sdv-id-'
assert id_no_regex_transformer.starting_value == 0

id_numeric_transformer = config['transformers']['id_numeric']
assert isinstance(id_numeric_transformer, IDGenerator)
assert id_numeric_transformer.prefix is None
assert id_numeric_transformer.starting_value == 0

id_column_transformer = config['transformers']['id_column']
assert isinstance(id_column_transformer, AnonymizedFaker)
assert id_column_transformer.function_name == 'bothify'
assert id_column_transformer.function_kwargs == {'text': '#####'}

dp.create_anonymized_transformer.calls == [
call('email', {'sdtype': 'email', 'pii': True, 'locales': locales}),
call('first_name', {'sdtype': 'first_name', 'locales': locales})
Expand Down

0 comments on commit 473ecee

Please sign in to comment.