Skip to content

Commit

Permalink
update integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Sep 18, 2024
1 parent bbe0ba6 commit 94b2860
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 151 deletions.
24 changes: 12 additions & 12 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_with_anonymized_columns(self):
data, metadata = download_demo('single_table', 'adult')

# Add anonymized field
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

# Instantiate ``DataProcessor``
dp = DataProcessor(metadata._convert_to_single_table())
Expand Down Expand Up @@ -101,11 +101,11 @@ def test_with_anonymized_columns_and_primary_key(self):
data, metadata = download_demo('single_table', 'adult')

# Add anonymized field
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

# Add primary key field
metadata.add_column('adult', 'id', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('adult', 'id')
metadata.add_column('id', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('id', 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -159,8 +159,8 @@ def test_with_primary_key_numerical(self):
adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')
adult_metadata.add_column('id', 'adult', sdtype='id')
adult_metadata.set_primary_key('id', 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -198,13 +198,13 @@ def test_with_alternate_keys(self):
adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')
adult_metadata.add_column('id', 'adult', sdtype='id')
adult_metadata.set_primary_key('id', 'adult')

adult_metadata.add_column('adult', 'secondary_id', sdtype='id')
adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]')
adult_metadata.add_column('secondary_id', 'adult', sdtype='id')
adult_metadata.update_column('fnlwgt', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]')

adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt'])
adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'], 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_localized_anonymized_columns(self):
"""Test data processor uses the default locale for anonymized columns."""
# Setup
data, metadata = download_demo('single_table', 'adult')
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA'])

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_evaluation():
data = pd.DataFrame({'col': [1, 2, 3]})
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('table', 'col', sdtype='numerical')
metadata.add_column('col', 'table', sdtype='numerical')
synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm')

# Run and Assert
Expand Down
46 changes: 19 additions & 27 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,54 +419,46 @@ def test_any_metadata_update_single_table(method, args, kwargs):
metadata.update_column(
table_name='fake_hotel_guests', column_name='billing_address', sdtype='street_address'
)
metadata_kwargs = deepcopy(metadata)
metadata_args = deepcopy(metadata)
metadata_kwargs_with_table_name = deepcopy(metadata)
metadata_args_with_table_name = deepcopy(metadata)
parameter = [kwargs[arg] for arg in args]
remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args}
metadata_before = deepcopy(metadata).to_dict()

# Run
result = getattr(metadata_kwargs, method)(**kwargs)
getattr(metadata_kwargs_with_table_name, method)(table_name='fake_hotel_guests', **kwargs)
arg_values = [kwargs[arg] for arg in args]
extra_param = {key: value for key, value in kwargs.items() if key not in args}
getattr(metadata_args, method)(*arg_values, **extra_param)
getattr(metadata_args_with_table_name, method)('fake_hotel_guests', *arg_values, **extra_param)
result = getattr(metadata, method)(*parameter, **remaining_kwargs)

# Assert
expected_dict = metadata_kwargs.to_dict()
expected_dict = metadata.to_dict()
if method != 'get_column_names':
assert expected_dict != metadata.to_dict()
assert expected_dict != metadata_before
else:
assert result == ['checkin_date', 'checkout_date']

other_metadata = [metadata_args, metadata_kwargs_with_table_name, metadata_args_with_table_name]
for metadata_obj in other_metadata:
assert expected_dict == metadata_obj.to_dict()


@pytest.mark.parametrize('method, args, kwargs', params)
def test_any_metadata_update_multi_table(method, args, kwargs):
"""Test that any method that updates metadata works for multi-table case."""
# Setup
args.insert(0, 'table_name')
kwargs['table_name'] = 'guests'
_, metadata = download_demo('multi_table', 'fake_hotels')
metadata.update_column(
table_name='guests', column_name='billing_address', sdtype='street_address'
)
metadata_kwargs = deepcopy(metadata)
metadata_args = deepcopy(metadata)
parameter = [kwargs[arg] for arg in args]
remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args}
metadata_before = deepcopy(metadata).to_dict()
expected_error = re.escape(
'Metadata contains more than one table, please specify the `table_name`.'
)

# Run
result = getattr(metadata_kwargs, method)(**kwargs)
arg_values = [kwargs[arg] for arg in args]
extra_param = {key: value for key, value in kwargs.items() if key not in args}
getattr(metadata_args, method)(*arg_values, **extra_param)
with pytest.raises(ValueError, match=expected_error):
getattr(metadata, method)(*parameter, **remaining_kwargs)

parameter.append('guests')
result = getattr(metadata, method)(*parameter, **remaining_kwargs)

# Assert
expected_dict = metadata_kwargs.to_dict()
assert expected_dict == metadata_args.to_dict()
expected_dict = metadata.to_dict()
if method != 'get_column_names':
assert expected_dict != metadata.to_dict()
assert expected_dict != metadata_before
else:
assert result == ['checkin_date', 'checkout_date']
12 changes: 6 additions & 6 deletions tests/integration/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes):

mock_rdt_transformers.address.RandomLocationGenerator = RandomLocationGeneratorMock
_, instance = download_demo('multi_table', 'fake_hotels')
instance.update_column('hotels', 'city', sdtype='city')
instance.update_column('hotels', 'state', sdtype='state')
instance.update_column('city', 'hotels', sdtype='city')
instance.update_column('state', 'hotels', sdtype='state')

# Run
instance.add_column_relationship('hotels', 'address', ['city', 'state'])
instance.add_column_relationship('address', ['city', 'state'], 'hotels')

# Assert
instance.validate()
Expand Down Expand Up @@ -303,9 +303,9 @@ def test_get_table_metadata():
"""Test the ``get_table_metadata`` method."""
# Setup
metadata = get_multi_table_metadata()
metadata.add_column('nesreca', 'latitude', sdtype='latitude')
metadata.add_column('nesreca', 'longitude', sdtype='longitude')
metadata.add_column_relationship('nesreca', 'gps', ['latitude', 'longitude'])
metadata.add_column('latitude', 'nesreca', sdtype='latitude')
metadata.add_column('longitude', 'nesreca', sdtype='longitude')
metadata.add_column_relationship('gps', ['latitude', 'longitude'], 'nesreca')

# Run
table_metadata = metadata.get_table_metadata('nesreca')
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/metadata/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_visualize_graph_for_multi_table():
data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
tables = {'1': data1, '2': data2}
metadata = Metadata.detect_from_dataframes(tables)
metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...')
metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '1', sdtype='id')
metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '2', sdtype='id')
metadata.set_primary_key('\\|=/bla@#$324%^,"&*()><...', '1')
metadata.add_relationship(
'1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...'
)
Expand Down
102 changes: 51 additions & 51 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def test_hma_reset_sampling(self):
faker = Faker()
data, metadata = download_demo('multi_table', 'got_families')
metadata.add_column(
'characters',
'ssn',
'characters',
sdtype='ssn',
)
data['characters']['ssn'] = [faker.lexify() for _ in range(len(data['characters']))]
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_get_info(self):
today = datetime.datetime.today().strftime('%Y-%m-%d')
metadata = Metadata()
metadata.add_table('tab')
metadata.add_column('tab', 'col', sdtype='numerical')
metadata.add_column('col', 'tab', sdtype='numerical')
synthesizer = HMASynthesizer(metadata)

# Run
Expand Down Expand Up @@ -221,12 +221,12 @@ def get_custom_constraint_data_and_metadata(self):

metadata = Metadata()
metadata.detect_table_from_dataframe('parent', parent_data)
metadata.update_column('parent', 'primary_key', sdtype='id')
metadata.update_column('primary_key', 'parent', sdtype='id')
metadata.detect_table_from_dataframe('child', child_data)
metadata.update_column('child', 'user_id', sdtype='id')
metadata.update_column('child', 'id', sdtype='id')
metadata.set_primary_key('parent', 'primary_key')
metadata.set_primary_key('child', 'id')
metadata.update_column('user_id', 'child', sdtype='id')
metadata.update_column('id', 'child', sdtype='id')
metadata.set_primary_key('primary_key', 'parent')
metadata.set_primary_key('id', 'child')
metadata.add_relationship(
parent_primary_key='primary_key',
parent_table_name='parent',
Expand Down Expand Up @@ -361,10 +361,10 @@ def test_hma_with_inequality_constraint(self):

metadata = Metadata()
metadata.detect_table_from_dataframe(table_name='parent_table', data=parent_table)
metadata.update_column('parent_table', 'id', sdtype='id')
metadata.update_column('id', 'parent_table', sdtype='id')
metadata.detect_table_from_dataframe(table_name='child_table', data=child_table)
metadata.update_column('child_table', 'id', sdtype='id')
metadata.update_column('child_table', 'parent_id', sdtype='id')
metadata.update_column('id', 'child_table', sdtype='id')
metadata.update_column('parent_id', 'child_table', sdtype='id')

metadata.set_primary_key(table_name='parent_table', column_name='id')
metadata.set_primary_key(table_name='child_table', column_name='id')
Expand Down Expand Up @@ -452,14 +452,14 @@ def test_hma_primary_key_and_foreign_key_only(self):
for table_name, table in data.items():
metadata.detect_table_from_dataframe(table_name, table)

metadata.update_column('users', 'user_id', sdtype='id')
metadata.update_column('sessions', 'session_id', sdtype='id')
metadata.update_column('games', 'game_id', sdtype='id')
metadata.update_column('games', 'session_id', sdtype='id')
metadata.update_column('games', 'user_id', sdtype='id')
metadata.set_primary_key('users', 'user_id')
metadata.set_primary_key('sessions', 'session_id')
metadata.set_primary_key('games', 'game_id')
metadata.update_column('user_id', 'users', sdtype='id')
metadata.update_column('session_id', 'sessions', sdtype='id')
metadata.update_column('game_id', 'games', sdtype='id')
metadata.update_column('session_id', 'games', sdtype='id')
metadata.update_column('user_id', 'games', sdtype='id')
metadata.set_primary_key('user_id', 'users')
metadata.set_primary_key('session_id', 'sessions')
metadata.set_primary_key('game_id', 'games')
metadata.add_relationship('users', 'games', 'user_id', 'user_id')
metadata.add_relationship('sessions', 'games', 'session_id', 'session_id')

Expand Down Expand Up @@ -1351,25 +1351,25 @@ def test_null_foreign_keys(self):
metadata = Metadata()

metadata.add_table('parent_table1')
metadata.add_column('parent_table1', 'id', sdtype='id')
metadata.set_primary_key('parent_table1', 'id')
metadata.add_column('id', 'parent_table1', sdtype='id')
metadata.set_primary_key('id', 'parent_table1')

metadata.add_table('parent_table2')
metadata.add_column('parent_table2', 'id', sdtype='id')
metadata.set_primary_key('parent_table2', 'id')
metadata.add_column('id', 'parent_table2', sdtype='id')
metadata.set_primary_key('id', 'parent_table2')

metadata.add_table('child_table1')
metadata.add_column('child_table1', 'id', sdtype='id')
metadata.set_primary_key('child_table1', 'id')
metadata.add_column('child_table1', 'fk1', sdtype='id')
metadata.add_column('child_table1', 'fk2', sdtype='id')
metadata.add_column('id', 'child_table1', sdtype='id')
metadata.set_primary_key('id', 'child_table1')
metadata.add_column('fk1', 'child_table1', sdtype='id')
metadata.add_column('fk2', 'child_table1', sdtype='id')

metadata.add_table('child_table2')
metadata.add_column('child_table2', 'id', sdtype='id')
metadata.set_primary_key('child_table2', 'id')
metadata.add_column('child_table2', 'fk1', sdtype='id')
metadata.add_column('child_table2', 'fk2', sdtype='id')
metadata.add_column('child_table2', 'cat_type', sdtype='categorical')
metadata.add_column('id', 'child_table2', sdtype='id')
metadata.set_primary_key('id', 'child_table2')
metadata.add_column('fk1', 'child_table2', sdtype='id')
metadata.add_column('fk2', 'child_table2', sdtype='id')
metadata.add_column('cat_type', 'child_table2', sdtype='categorical')

metadata.add_relationship(
parent_table_name='parent_table1',
Expand Down Expand Up @@ -1842,7 +1842,7 @@ def test_fit_and_sample_numerical_col_names():
}
]
metadata = Metadata.load_from_dict(metadata_dict)
metadata.set_primary_key('0', '1')
metadata.set_primary_key('1', '0')

# Run
synth = HMASynthesizer(metadata)
Expand Down Expand Up @@ -1875,11 +1875,11 @@ def test_detect_from_dataframe_numerical_col():
metadata = Metadata()
metadata.detect_table_from_dataframe('parent_data', parent_data)
metadata.detect_table_from_dataframe('child_data', child_data)
metadata.update_column('parent_data', '1', sdtype='id')
metadata.update_column('child_data', '3', sdtype='id')
metadata.update_column('child_data', '4', sdtype='id')
metadata.set_primary_key('parent_data', '1')
metadata.set_primary_key('child_data', '4')
metadata.update_column('1', 'parent_data', sdtype='id')
metadata.update_column('3', 'child_data', sdtype='id')
metadata.update_column('4', 'child_data', sdtype='id')
metadata.set_primary_key('1', 'parent_data')
metadata.set_primary_key('4', 'child_data')
metadata.add_relationship(
parent_primary_key='1',
parent_table_name='parent_data',
Expand All @@ -1888,11 +1888,11 @@ def test_detect_from_dataframe_numerical_col():
)

test_metadata = Metadata.detect_from_dataframes(data)
test_metadata.update_column('parent_data', '1', sdtype='id')
test_metadata.update_column('child_data', '3', sdtype='id')
test_metadata.update_column('child_data', '4', sdtype='id')
test_metadata.set_primary_key('parent_data', '1')
test_metadata.set_primary_key('child_data', '4')
test_metadata.update_column('1', 'parent_data', sdtype='id')
test_metadata.update_column('3', 'child_data', sdtype='id')
test_metadata.update_column('4', 'child_data', sdtype='id')
test_metadata.set_primary_key('1', 'parent_data')
test_metadata.set_primary_key('4', 'child_data')
test_metadata.add_relationship(
parent_primary_key='1',
parent_table_name='parent_data',
Expand Down Expand Up @@ -2005,13 +2005,13 @@ def test_hma_synthesizer_with_fixed_combinations():
# Creating metadata for the dataset
metadata = Metadata.detect_from_dataframes(data)

metadata.update_column('users', 'user_id', sdtype='id')
metadata.update_column('records', 'record_id', sdtype='id')
metadata.update_column('records', 'user_id', sdtype='id')
metadata.update_column('records', 'location_id', sdtype='id')
metadata.update_column('locations', 'location_id', sdtype='id')
metadata.set_primary_key('users', 'user_id')
metadata.set_primary_key('locations', 'location_id')
metadata.update_column('user_id', 'users', sdtype='id')
metadata.update_column('record_id', 'records', sdtype='id')
metadata.update_column('user_id', 'records', sdtype='id')
metadata.update_column('location_id', 'records', sdtype='id')
metadata.update_column('location_id', 'locations', sdtype='id')
metadata.set_primary_key('user_id', 'users')
metadata.set_primary_key('location_id', 'locations')
metadata.add_relationship('users', 'records', 'user_id', 'user_id')
metadata.add_relationship('locations', 'records', 'location_id', 'location_id')

Expand Down Expand Up @@ -2053,8 +2053,8 @@ def test_fit_int_primary_key_regex_includes_zero(regex):
'child_data': child_data,
}
metadata = Metadata.detect_from_dataframes(data)
metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex)
metadata.set_primary_key('parent_data', 'parent_id')
metadata.update_column('parent_id', 'parent_data', sdtype='id', regex_format=regex)
metadata.set_primary_key('parent_id', 'parent_data')

# Run and Assert
instance = HMASynthesizer(metadata)
Expand Down
Loading

0 comments on commit 94b2860

Please sign in to comment.