Skip to content

Commit

Permalink
Add a minimum number of rows for sample
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jun 13, 2024
1 parent 18f392c commit 698f391
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
14 changes: 13 additions & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Hierarchical Samplers."""
import logging
import warnings

import pandas as pd

Expand Down Expand Up @@ -138,7 +139,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
sampled_data (dict):
A dictionary mapping table names to sampled data (pd.DataFrame).
"""
total_num_rows = round(self._table_sizes[child_name] * scale)
total_num_rows = max(round(self._table_sizes[child_name] * scale), 1)
for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key]
Expand Down Expand Up @@ -273,13 +274,24 @@ def _sample(self, scale=1.0):
# DFS to sample roots and then their children
non_root_parents = set(self.metadata._get_parent_map().keys())
root_parents = set(self.metadata.tables.keys()) - non_root_parents
send_min_sample_warning = False
for table in root_parents:
num_rows = round(self._table_sizes[table] * scale)
if num_rows <= 0:
send_min_sample_warning = True
num_rows = 1
synthesizer = self._table_synthesizers[table]
LOGGER.info(f'Sampling {num_rows} rows from table {table}')
sampled_data[table] = self._sample_rows(synthesizer, num_rows)
self._sample_children(table_name=table, sampled_data=sampled_data, scale=scale)

if send_min_sample_warning:
warn_msg = (
"The 'scale' parameter it too small. Some tables may have 1 row."
' For better quality data, please choose a larger scale.'
)
warnings.warn(warn_msg)

added_relationships = set()
for relationship in self.metadata.relationships:
parent_name = relationship['parent_table_name']
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,3 +1839,27 @@ def test_disjointed_tables():
# Assert
for table in real_data:
assert list(real_data[table].columns) == list(disjoin_synthetic_data[table].columns)


def test_small_sample():
    """Test that sampling with a very small ``scale`` still produces data.

    The synthesizer is fit on the ``fake_hotels`` demo dataset and sampled with
    ``scale=0.01``. The call is expected to emit the "scale too small" warning,
    return exactly one ``hotels`` row, return at least 1% as many ``guests``
    rows as the real table, and keep the real column order in both tables.
    """
    # Setup
    real_data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels')
    hma = HMASynthesizer(metadata)
    hma.fit(real_data)
    expected_warning = re.escape(
        "The 'scale' parameter it too small. Some tables may have 1 row."
        ' For better quality data, please choose a larger scale.'
    )

    # Run
    with pytest.warns(Warning, match=expected_warning):
        sampled = hma.sample(scale=0.01)

    # Assert
    assert len(sampled['hotels']) == 1
    assert len(sampled['guests']) >= len(real_data['guests']) * 0.01
    for table_name in ('hotels', 'guests'):
        expected_columns = real_data[table_name].columns.tolist()
        assert sampled[table_name].columns.tolist() == expected_columns

0 comments on commit 698f391

Please sign in to comment.