Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use IDGenerator for ID columns #1538

Merged
7 commits merged on Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import rdt
from pandas.api.types import is_float_dtype, is_integer_dtype
from rdt.transformers import AnonymizedFaker, RegexGenerator, get_default_transformers
from rdt.transformers import AnonymizedFaker, IDGenerator, RegexGenerator, get_default_transformers

from sdv.constraints import Constraint
from sdv.constraints.base import get_subclasses
Expand Down Expand Up @@ -458,13 +458,28 @@ def _create_config(self, data, columns_created_by_constraints):

if sdtype == 'id':
is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype)
transformers[column] = self.create_regex_generator(
column,
sdtype,
column_metadata,
is_numeric
)
sdtypes[column] = 'text'
if column_metadata.get('regex_format', False):
transformers[column] = self.create_regex_generator(
column,
sdtype,
column_metadata,
is_numeric
)
sdtypes[column] = 'text'
frances-h marked this conversation as resolved.
Show resolved Hide resolved
elif column == self.metadata.primary_key or column in self.metadata.alternate_keys:
frances-h marked this conversation as resolved.
Show resolved Hide resolved
prefix = None
if not is_numeric:
prefix = 'sdv-id-'

transformers[column] = IDGenerator(prefix=prefix)
sdtypes[column] = 'text'
else:
transformers[column] = AnonymizedFaker(
provider_name=None,
function_name='bothify',
function_kwargs={'text': '#####'}
)
sdtypes[column] = 'pii'

elif pii:
enforce_uniqueness = bool(column in self._keys)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.4,<0.8',
'deepecho>=0.4.2,<0.5',
'rdt>=1.6.1,<2',
'rdt>=1.7.0.dev0',
'sdmetrics>=0.11.0,<0.12',
'cloudpickle>=2.1.0,<3.0',
'boto3>=1.15.0,<2',
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, LabelEncoder, RegexGenerator,
AnonymizedFaker, BinaryEncoder, FloatFormatter, RegexGenerator, UniformEncoder,
UnixTimestampEncoder)

from sdv.data_processing import DataProcessor
Expand Down Expand Up @@ -247,23 +247,23 @@ def test_data_processor_prepare_for_fitting():
# Assert
field_transformers = dp._hyper_transformer.field_transformers
expected_transformers = {
'mba_spec': LabelEncoder,
'mba_spec': UniformEncoder,
'employability_perc': FloatFormatter,
'placed': LabelEncoder,
'placed': UniformEncoder,
'student_id': RegexGenerator,
'experience_years': FloatFormatter,
'duration': LabelEncoder,
'duration': UniformEncoder,
'salary': FloatFormatter,
'second_perc': FloatFormatter,
'start_date': UnixTimestampEncoder,
'address': AnonymizedFaker,
'gender': LabelEncoder,
'gender': UniformEncoder,
'mba_perc': FloatFormatter,
'degree_type': LabelEncoder,
'degree_type': UniformEncoder,
'end_date': UnixTimestampEncoder,
'high_spec': LabelEncoder,
'high_spec': UniformEncoder,
'high_perc': FloatFormatter,
'work_experience': LabelEncoder,
'work_experience': UniformEncoder,
'degree_perc': FloatFormatter
}
for column_name, transformer_class in expected_transformers.items():
Expand Down
6 changes: 2 additions & 4 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import pkg_resources
import pytest
from rdt.transformers import AnonymizedFaker, FloatFormatter, LabelEncoder, RegexGenerator
from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder

from sdv.metadata import SingleTableMetadata
from sdv.sampling import Condition
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_transformers_correctly_auto_assigned():
assert isinstance(transformers['numerical_col'], FloatFormatter)
assert isinstance(transformers['pii_col'], AnonymizedFaker)
assert isinstance(transformers['primary_key'], RegexGenerator)
assert isinstance(transformers['categorical_col'], LabelEncoder)
assert isinstance(transformers['categorical_col'], UniformEncoder)

assert transformers['numerical_col'].missing_value_replacement == 'mean'
assert transformers['numerical_col'].missing_value_generation == 'random'
Expand All @@ -265,8 +265,6 @@ def test_transformers_correctly_auto_assigned():
assert transformers['primary_key'].regex_format == 'user-[0-9]{3}'
assert transformers['primary_key'].enforce_uniqueness is True

assert transformers['categorical_col'].add_noise is True


def test_modeling_with_complex_datetimes():
"""Test that models work with datetimes passed as strings or ints with complex format."""
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/single_table/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,13 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc
model.add_constraints([constraint])
sampled_numeric_data = [
pd.DataFrame({
'city#state': [0, 1, 2, 0, 0],
'city#state': [0.1, 1, 0.75, 0.25, 0.25],
'age': [30, 30, 30, 30, 30]
}),
pd.DataFrame({
'city#state': [1],
'city#state': [0.75],
'age': [30]
})
}),
]
gm_mock.return_value.sample.side_effect = sampled_numeric_data
model.fit(data)
Expand Down
60 changes: 44 additions & 16 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest
from rdt.errors import ConfigNotSetError
from rdt.errors import NotFittedError as RDTNotFittedError
from rdt.transformers import FloatFormatter, LabelEncoder, UnixTimestampEncoder
from rdt.transformers import (
AnonymizedFaker, FloatFormatter, IDGenerator, UniformEncoder, UnixTimestampEncoder)

from sdv.constraints.errors import MissingConstraintColumnError
from sdv.constraints.tabular import Positive, ScalarRange
Expand Down Expand Up @@ -73,8 +74,8 @@ def test___init__(

mock_default_transformers.return_value = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'categorical': 'UniformEncoder()',
'boolean': 'UniformEncoder()',
'datetime': 'UnixTimestampEncoder()',
'text': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
Expand Down Expand Up @@ -114,8 +115,8 @@ def test___init__(

expected_default_transformers = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'categorical': 'UniformEncoder()',
'boolean': 'UniformEncoder()',
'datetime': 'UnixTimestampEncoder()',
'id': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
Expand Down Expand Up @@ -758,7 +759,7 @@ def test__update_transformers_by_sdtypes(self):
# Setup
instance = Mock()
instance._transformers_by_sdtype = {
'categorical': 'labelencoder',
'categorical': 'UniformEncoder',
'numerical': 'float',
'boolean': None
}
Expand Down Expand Up @@ -924,7 +925,7 @@ def test__get_transformer_instance_no_kwargs(self):
dp = Mock()
dp._transformers_by_sdtype = {
'numerical': 'FloatFormatter',
'categorical': 'LabelEncoder'
'categorical': 'UniformEncoder'
}

# Run
Expand Down Expand Up @@ -968,7 +969,7 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log):
'map_col#cat_col': ['z#a', 'x#b'],
'low#high': [0.2, 0.5]
})
dp._get_transformer_instance = Mock(return_value='LabelEncoder')
dp._get_transformer_instance = Mock(return_value='UniformEncoder')
mock_rdt.transformers.FloatFormatter.return_value = 'FloatFormatter'

config = {
Expand All @@ -977,8 +978,8 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log):
'pii_col': 'FloatFormatter',
'low': 'FloatFormatter',
'high': 'FloatFormatter',
'cat_col': 'LabelEncoder',
'map_col': 'LabelEncoder',
'cat_col': 'UniformEncoder',
'map_col': 'UniformEncoder',
},
'sdtypes': {
'id_col': 'numerical',
Expand All @@ -1001,8 +1002,8 @@ def test__update_constraint_transformers(self, mock_rdt, mock_log):
'pii_col': 'FloatFormatter',
'low': 'FloatFormatter',
'high': 'FloatFormatter',
'map_col': 'LabelEncoder',
'map_col#cat_col': 'LabelEncoder',
'map_col': 'UniformEncoder',
'map_col#cat_col': 'UniformEncoder',
'low#high': 'FloatFormatter'
},
'sdtypes': {
Expand Down Expand Up @@ -1052,15 +1053,21 @@ def test__create_config(self):
'email': ['[email protected]', '[email protected]', '[email protected]'],
'first_name': ['John', 'Doe', 'Johanna'],
'id': ['ID_001', 'ID_002', 'ID_003'],
'id_no_regex': ['ID_001', 'ID_002', 'ID_003'],
'id_numeric': [0, 1, 2],
'id_column': ['ID_999', 'ID_999', 'ID_007'],
'date': ['2021-02-01', '2022-03-05', '2023-01-31']
})
dp = DataProcessor(SingleTableMetadata(), locales=locales)
dp.metadata = Mock()
dp.create_anonymized_transformer = Mock()
dp.create_regex_generator = Mock()
dp.create_id_generator = Mock()
dp.create_anonymized_transformer.return_value = 'AnonymizedFaker'
dp.create_regex_generator.return_value = 'RegexGenerator'
dp.create_id_generator.return_value = 'IDGenerator'
dp.metadata.primary_key = 'id'
dp.metadata.alternate_keys = ['id_no_regex', 'id_numeric']
dp._primary_key = 'id'
dp._keys = ['id']
dp.metadata.columns = {
Expand All @@ -1071,6 +1078,9 @@ def test__create_config(self):
'email': {'sdtype': 'email', 'pii': True},
'first_name': {'sdtype': 'first_name'},
'id': {'sdtype': 'id', 'regex_format': 'ID_\\d{3}[0-9]'},
'id_no_regex': {'sdtype': 'id'},
'id_numeric': {'sdtype': 'id'},
'id_column': {'sdtype': 'id'},
'date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}
}

Expand All @@ -1091,6 +1101,9 @@ def test__create_config(self):
'email': 'pii',
'first_name': 'pii',
'id': 'text',
'id_no_regex': 'text',
'id_numeric': 'text',
'id_column': 'pii',
'date': 'datetime'
}

Expand All @@ -1104,10 +1117,10 @@ def test__create_config(self):
assert float_transformer.missing_value_replacement == 'mean'
assert float_transformer.missing_value_generation == 'random'

assert isinstance(config['transformers']['bool'], LabelEncoder)
assert isinstance(config['transformers']['created_bool'], LabelEncoder)
assert isinstance(config['transformers']['categorical'], LabelEncoder)
assert isinstance(config['transformers']['created_categorical'], LabelEncoder)
assert isinstance(config['transformers']['bool'], UniformEncoder)
assert isinstance(config['transformers']['created_bool'], UniformEncoder)
assert isinstance(config['transformers']['categorical'], UniformEncoder)
assert isinstance(config['transformers']['created_categorical'], UniformEncoder)

assert isinstance(config['transformers']['int'], FloatFormatter)
assert isinstance(config['transformers']['float'], FloatFormatter)
Expand All @@ -1126,6 +1139,21 @@ def test__create_config(self):
assert datetime_transformer.datetime_format == '%Y-%m-%d'
assert dp._primary_key == 'id'

id_no_regex_transformer = config['transformers']['id_no_regex']
assert isinstance(id_no_regex_transformer, IDGenerator)
assert id_no_regex_transformer.prefix == 'sdv-id-'
assert id_no_regex_transformer.starting_value == 0

id_numeric_transformer = config['transformers']['id_numeric']
assert isinstance(id_numeric_transformer, IDGenerator)
assert id_numeric_transformer.prefix is None
assert id_numeric_transformer.starting_value == 0

id_column_transformer = config['transformers']['id_column']
assert isinstance(id_column_transformer, AnonymizedFaker)
assert id_column_transformer.function_name == 'bothify'
assert id_column_transformer.function_kwargs == {'text': '#####'}

dp.create_anonymized_transformer.calls == [
call('email', {'sdtype': 'email', 'pii': True, 'locales': locales}),
call('first_name', {'sdtype': 'first_name', 'locales': locales})
Expand Down
Loading