Commit: Improve usage of detect_from_dataframes function (#2221)
amontanez24 committed Sep 12, 2024
1 parent 065a128 commit b1a7f7d
Showing 13 changed files with 101 additions and 73 deletions.
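Every file below applies the same substitution: the two-step detection flow is replaced by the new classmethod added in sdv/metadata/metadata.py. A minimal sketch of the pattern, using an illustrative table dictionary that is not taken from the diff:

import pandas as pd
from sdv.metadata import Metadata

# Illustrative input; any dict of table name -> pandas.DataFrame works.
dataframes_by_table_name = {'users': pd.DataFrame({'user_id': [1, 2, 3]})}

# Old pattern (removed throughout this diff):
#     metadata = Metadata()
#     metadata.detect_from_dataframes(dataframes_by_table_name)

# New pattern: a single classmethod call returns the detected Metadata object.
metadata = Metadata.detect_from_dataframes(dataframes_by_table_name)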
3 changes: 1 addition & 2 deletions sdv/io/local/local.py
@@ -29,8 +29,7 @@ def create_metadata(self, data):
An ``sdv.metadata.Metadata`` object with the detected metadata
properties from the data.
"""
- metadata = Metadata()
- metadata.detect_from_dataframes(data)
+ metadata = Metadata.detect_from_dataframes(data)
return metadata

def read(self):
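For reference, a sketch of how the updated handler method is used. The CSVHandler class name and its default constructor are assumptions for illustration; only create_metadata comes from this diff:

import pandas as pd
from sdv.io.local import CSVHandler  # assumed handler class, shown for illustration only

handler = CSVHandler()
# In practice the dict would come from the handler's read method; a small stand-in is used here.
data = {'users': pd.DataFrame({'user_id': [1, 2, 3]})}
metadata = handler.create_metadata(data)  # now delegates to Metadata.detect_from_dataframes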
25 changes: 25 additions & 0 deletions sdv/metadata/metadata.py
@@ -53,6 +53,31 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
instance._set_metadata_dict(metadata_dict, single_table_name)
return instance

+ @classmethod
+ def detect_from_dataframes(cls, data):
+     """Detect the metadata for all tables in a dictionary of dataframes.
+
+     This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``.
+     All data column names are converted to strings.
+
+     Args:
+         data (dict):
+             Dictionary of table names to dataframes.
+
+     Returns:
+         Metadata:
+             A new metadata object with the sdtypes detected from the data.
+     """
+     if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
+         raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')
+
+     metadata = Metadata()
+     for table_name, dataframe in data.items():
+         metadata.detect_table_from_dataframe(table_name, dataframe)
+
+     metadata._detect_relationships(data)
+     return metadata
+
@classmethod
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
"""Detect the metadata for a DataFrame.
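For context, a minimal usage sketch of the new classmethod; the table names and dataframes are illustrative and not part of the change:

import pandas as pd
from sdv.metadata import Metadata

# Illustrative multi-table input: a dict mapping table names to pandas DataFrames.
hotels = pd.DataFrame({'hotel_id': [1, 2], 'city': ['NYC', 'SF']})
guests = pd.DataFrame({'guest_id': [10, 11, 12], 'hotel_id': [1, 1, 2]})

metadata = Metadata.detect_from_dataframes({'hotels': hotels, 'guests': guests})
metadata.validate()

# An empty dict or non-DataFrame values raise the ValueError added in this method.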
6 changes: 2 additions & 4 deletions tests/integration/data_processing/test_data_processor.py
@@ -156,8 +156,7 @@ def test_with_primary_key_numerical(self):
"""
# Load metadata and data
data, _ = download_demo('single_table', 'adult')
- adult_metadata = Metadata()
- adult_metadata.detect_from_dataframes({'adult': data})
+ adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
@@ -196,8 +195,7 @@ def test_with_alternate_keys(self):
# Load metadata and data
data, _ = download_demo('single_table', 'adult')
data['fnlwgt'] = data['fnlwgt'].astype(str)
- adult_metadata = Metadata()
- adult_metadata.detect_from_dataframes({'adult': data})
+ adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
9 changes: 3 additions & 6 deletions tests/integration/lite/test_single_table.py
@@ -12,8 +12,7 @@ def test_sample():
data = pd.DataFrame({'a': [1, 2, 3, np.nan]})

# Run
- metadata = Metadata()
- metadata.detect_from_dataframes({'adult': data})
+ metadata = Metadata.detect_from_dataframes({'adult': data})
preset = SingleTablePreset(metadata, name='FAST_ML')
preset.fit(data)
samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5)
@@ -29,8 +28,7 @@ def test_sample_with_constraints():
data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

# Run
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
preset = SingleTablePreset(metadata, name='FAST_ML')
constraints = [
{
@@ -57,8 +55,7 @@ def test_warnings_are_shown():
data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

# Run
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})

with pytest.warns(FutureWarning, match=warn_message):
preset = SingleTablePreset(metadata, name='FAST_ML')
7 changes: 2 additions & 5 deletions tests/integration/metadata/test_metadata.py
@@ -32,10 +32,8 @@ def test_detect_from_dataframes_multi_table():
# Setup
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')

- metadata = Metadata()
-
# Run
- metadata.detect_from_dataframes(real_data)
+ metadata = Metadata.detect_from_dataframes(real_data)

# Assert
metadata.update_column(
@@ -90,8 +88,7 @@ def test_detect_from_dataframes_single_table():
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')

- metadata = Metadata()
- metadata.detect_from_dataframes({'table_1': data['hotels']})
+ metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']})

# Run
metadata.validate()
6 changes: 2 additions & 4 deletions tests/integration/metadata/test_visualization.py
@@ -9,8 +9,7 @@ def test_visualize_graph_for_single_table():
"""Test it runs when a column name contains symbols."""
# Setup
data = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
model = GaussianCopulaSynthesizer(metadata)

# Run
@@ -26,8 +25,7 @@ def test_visualize_graph_for_multi_table():
data1 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
tables = {'1': data1, '2': data2}
- metadata = Metadata()
- metadata.detect_from_dataframes(tables)
+ metadata = Metadata.detect_from_dataframes(tables)
metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...')
24 changes: 8 additions & 16 deletions tests/integration/multi_table/test_hma.py
@@ -1286,8 +1286,7 @@ def test_metadata_updated_no_warning(self, tmp_path):
assert len(captured_warnings) == 0

# Run 2
- metadata_detect = Metadata()
- metadata_detect.detect_from_dataframes(data)
+ metadata_detect = Metadata.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
for table_name, table_metadata in metadata.tables.items():
@@ -1326,8 +1325,7 @@ def test_metadata_updated_warning_detect(self):
"""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
- metadata_detect = Metadata()
- metadata_detect.detect_from_dataframes(data)
+ metadata_detect = Metadata.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
for table_name, table_metadata in metadata.tables.items():
@@ -1456,8 +1454,7 @@ def test_sampling_with_unknown_sdtype_numerical_column(self):

tables_dict = {'people': table1, 'company': table2}

- metadata = Metadata()
- metadata.detect_from_dataframes(tables_dict)
+ metadata = Metadata.detect_from_dataframes(tables_dict)

# Run
synth = HMASynthesizer(metadata)
@@ -1890,8 +1887,7 @@ def test_detect_from_dataframe_numerical_col():
child_table_name='child_data',
)

- test_metadata = Metadata()
- test_metadata.detect_from_dataframes(data)
+ test_metadata = Metadata.detect_from_dataframes(data)
test_metadata.update_column('parent_data', '1', sdtype='id')
test_metadata.update_column('child_data', '3', sdtype='id')
test_metadata.update_column('child_data', '4', sdtype='id')
@@ -1914,8 +1910,7 @@ def test_detect_from_dataframe_numerical_col():
assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist()
assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist()

- test_metadata = Metadata()
- test_metadata.detect_from_dataframes(data)
+ test_metadata = Metadata.detect_from_dataframes(data)


def test_table_name_logging(caplog):
@@ -1930,8 +1925,7 @@ def test_table_name_logging(caplog):
'parent_data': parent_data,
'child_data': child_data,
}
- metadata = Metadata()
- metadata.detect_from_dataframes(data)
+ metadata = Metadata.detect_from_dataframes(data)
instance = HMASynthesizer(metadata)

# Run
@@ -2009,8 +2003,7 @@ def test_hma_synthesizer_with_fixed_combinations():
}

# Creating metadata for the dataset
- metadata = Metadata()
- metadata.detect_from_dataframes(data)
+ metadata = Metadata.detect_from_dataframes(data)

metadata.update_column('users', 'user_id', sdtype='id')
metadata.update_column('records', 'record_id', sdtype='id')
@@ -2059,8 +2052,7 @@ def test_fit_int_primary_key_regex_includes_zero(regex):
'parent_data': parent_data,
'child_data': child_data,
}
- metadata = Metadata()
- metadata.detect_from_dataframes(data)
+ metadata = Metadata.detect_from_dataframes(data)
metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex)
metadata.set_primary_key('parent_data', 'parent_id')

12 changes: 4 additions & 8 deletions tests/integration/sequential/test_par.py
@@ -21,8 +21,7 @@ def _get_par_data_and_metadata():
'entity': [1, 1, 2, 2],
'context': ['a', 'a', 'b', 'b'],
})
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('table', 'entity', sdtype='id')
metadata.set_sequence_key('table', 'entity')
metadata.set_sequence_index('table', 'date')
@@ -34,8 +33,7 @@ def test_par():
# Setup
data = load_demo()
data['date'] = pd.to_datetime(data['date'])
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('table', 'store_id', sdtype='id')
metadata.set_sequence_key('table', 'store_id')
metadata.set_sequence_index('table', 'date')
@@ -68,8 +66,7 @@ def test_column_after_date_simple():
'date': [date, date],
'col2': ['hello', 'world'],
})
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('table', 'col', sdtype='id')
metadata.set_sequence_key('table', 'col')
metadata.set_sequence_index('table', 'date')
@@ -348,8 +345,7 @@ def test_par_unique_sequence_index_with_enforce_min_max():
test_df[['visits', 'pre_date']] = test_df[['visits', 'pre_date']].apply(
pd.to_datetime, format='%Y-%m-%d', errors='coerce'
)
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': test_df})
+ metadata = Metadata.detect_from_dataframes({'table': test_df})
metadata.update_column(table_name='table', column_name='s_key', sdtype='id')
metadata.set_sequence_key('table', 's_key')
metadata.set_sequence_index('table', 'visits')
21 changes: 7 additions & 14 deletions tests/integration/single_table/test_base.py
@@ -313,10 +313,9 @@ def test_config_creation_doesnt_raise_error():
'address_col': ['223 Williams Rd', '75 Waltham St', '77 Mass Ave'],
'numerical_col': [1, 2, 3],
})
- test_metadata = Metadata()

# Run
- test_metadata.detect_from_dataframes({'table': test_data})
+ test_metadata = Metadata.detect_from_dataframes({'table': test_data})
test_metadata.update_column(
table_name='table', column_name='address_col', sdtype='address', pii=False
)
@@ -335,8 +334,7 @@ def test_transformers_correctly_auto_assigned():
'categorical_col': ['a', 'b', 'a'],
})

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column(
table_name='table', column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}'
)
@@ -425,8 +423,7 @@ def test_auto_assign_transformers_and_update_with_pii():
}
)

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})

# Run
metadata.update_column(table_name='table', column_name='id', sdtype='first_name')
@@ -458,8 +455,7 @@ def test_refitting_a_model():
}
)

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column(table_name='table', column_name='name', sdtype='name')
metadata.update_column('table', 'id', sdtype='id')
metadata.set_primary_key('table', 'id')
@@ -619,8 +615,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path):
assert len(captured_warnings) == 0

# Run 2
- metadata_detect = Metadata()
- metadata_detect.detect_from_dataframes({'mock_table': data})
+ metadata_detect = Metadata.detect_from_dataframes({'mock_table': data})
file_name = tmp_path / 'singletable.json'
metadata_detect.save_to_json(file_name)
with warnings.catch_warnings(record=True) as captured_warnings:
@@ -737,8 +732,7 @@ def test_fit_raises_version_error():
'col 2': [4, 5, 6],
'col 3': ['a', 'b', 'c'],
})
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
instance = BaseSingleTableSynthesizer(metadata)
instance._fitted_sdv_version = '1.0.0'

@@ -813,10 +807,9 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class):
2: [4, 5, 6],
3: ['a', 'b', 'c'],
})
- metadata = Metadata()

# Run
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
instance = synthesizer_class(metadata)
instance.fit(data)
sample = instance.sample(5)
9 changes: 3 additions & 6 deletions tests/integration/single_table/test_constraints.py
@@ -339,8 +339,7 @@ def test_custom_constraints_from_file(tmpdir):
'categorical_col': ['a', 'b', 'a'],
})

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True)
synthesizer = GaussianCopulaSynthesizer(
metadata, enforce_min_max_values=False, enforce_rounding=False
@@ -383,8 +382,7 @@ def test_custom_constraints_from_object(tmpdir):
'categorical_col': ['a', 'b', 'a'],
})

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True)
synthesizer = GaussianCopulaSynthesizer(
metadata, enforce_min_max_values=False, enforce_rounding=False
@@ -816,8 +814,7 @@ def reverse_transform(column_names, data):
'number': ['1', '2', '3'],
'other': [7, 8, 9],
})
- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('table', 'key', sdtype='id', regex_format=r'\w_\d')
metadata.set_primary_key('table', 'key')
synth = GaussianCopulaSynthesizer(metadata)
9 changes: 3 additions & 6 deletions tests/integration/single_table/test_copulas.py
@@ -277,8 +277,7 @@ def test_update_transformers_with_id_generator():
sample_num = 20
data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']})

- stm = Metadata()
- stm.detect_from_dataframes({'table': data})
+ stm = Metadata.detect_from_dataframes({'table': data})
stm.update_column('table', 'user_id', sdtype='id')
stm.set_primary_key('table', 'user_id')

@@ -406,8 +405,7 @@ def test_categorical_column_with_numbers():
'numerical_col': np.random.rand(20),
})

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})

synthesizer = GaussianCopulaSynthesizer(metadata)

@@ -435,8 +433,7 @@ def test_unknown_sdtype():
'numerical_col': np.random.rand(3),
})

- metadata = Metadata()
- metadata.detect_from_dataframes({'table': data})
+ metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('table', 'unknown', sdtype='unknown')

synthesizer = GaussianCopulaSynthesizer(metadata)
(Diffs for the remaining 2 changed files are not shown.)
