Skip to content

Commit

Permalink
switch HMA over to using HierarchicalSampler mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Jun 13, 2023
1 parent 830e5df commit 6661b9a
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 523 deletions.
253 changes: 39 additions & 214 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from tqdm import tqdm

from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.sampling import BaseHierarchicalSampler

LOGGER = logging.getLogger(__name__)


class HMASynthesizer(BaseMultiTableSynthesizer):
class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
"""Hierarchical Modeling Algorithm One.
Args:
Expand All @@ -31,13 +32,19 @@ class HMASynthesizer(BaseMultiTableSynthesizer):
}

def __init__(self, metadata, locales=None, verbose=True):
super().__init__(metadata, locales=locales)
BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales)
self._table_sizes = {}
self._max_child_rows = {}
self._augmented_tables = []
self._learned_relationships = 0
self.verbose = verbose

BaseHierarchicalSampler.__init__(
self,
self.metadata,
self._table_synthesizers,
self._table_sizes)

def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
"""Generate the extension columns for this child table.
Expand Down Expand Up @@ -115,15 +122,6 @@ def _clear_nans(table_data):

table_data[column] = table_data[column].fillna(fill_value)

def _get_foreign_keys(self, table_name, child_name):
foreign_keys = []
for relation in self.metadata.relationships:
if table_name == relation['parent_table_name'] and\
child_name == relation['child_table_name']:
foreign_keys.append(deepcopy(relation['child_foreign_key']))

return foreign_keys

def _augment_table(self, table, tables, table_name):
"""Generate the extension columns for this table.
Expand Down Expand Up @@ -152,7 +150,7 @@ def _augment_table(self, table, tables, table_name):
else:
child_table = tables[child_name]

foreign_keys = self._get_foreign_keys(table_name, child_name)
foreign_keys = self.metadata._get_foreign_keys(table_name, child_name)
for foreign_key in foreign_keys:
progress_bar_desc = (
f'({self._learned_relationships + 1}/{len(self.metadata.relationships)}) '
Expand Down Expand Up @@ -241,46 +239,6 @@ def _augment_tables(self, processed_data):
LOGGER.info('Augmentation Complete')
return augmented_data

def _finalize(self, sampled_data):
"""Do the final touches to the generated data.
This method reverts the previous transformations to go back
to values in the original space and also adds the parent
keys in case foreign key relationships exist between the tables.
Args:
sampled_data (dict):
Generated data
Return:
pandas.DataFrame:
Formatted synthesized data.
"""
final_data = {}
for table_name, table_rows in sampled_data.items():
parents = self.metadata._get_parent_map().get(table_name)
if parents:
for parent_name in parents:
foreign_keys = self._get_foreign_keys(parent_name, table_name)
for foreign_key in foreign_keys:
if foreign_key not in table_rows:
parent_ids = self._find_parent_ids(
table_name,
parent_name,
foreign_key,
sampled_data
)
table_rows[foreign_key] = parent_ids.to_numpy()

synthesizer = self._table_synthesizers.get(table_name)
dtypes = synthesizer._data_processor._dtypes
for name, dtype in dtypes.items():
table_rows[name] = table_rows[name].dropna().astype(dtype)

final_data[table_name] = table_rows[list(dtypes.keys())]

return final_data

def _extract_parameters(self, parent_row, table_name, foreign_key):
"""Get the params from a generated parent row.
Expand Down Expand Up @@ -308,105 +266,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key):

return flat_parameters.rename(new_keys).to_dict()

def _process_samples(self, table_name, sampled_rows):
"""Process the ``sampled_rows`` for the given ``table_name``.
Process the raw samples and convert them to the original space by reverse transforming
them. Also, when there are synthesizer columns (columns used to recreate an instance
of a synthesizer), those will be returned together.
"""
data_processor = self._table_synthesizers[table_name]._data_processor
sampled = data_processor.reverse_transform(sampled_rows)

synthesizer_columns = list(set(sampled_rows.columns) - set(sampled.columns))
if synthesizer_columns:
sampled = pd.concat([sampled, sampled_rows[synthesizer_columns]], axis=1)

return sampled

def _sample_rows(self, synthesizer, table_name, num_rows=None):
"""Sample ``num_rows`` from ``synthesizer``.
Args:
synthesizer (copula.multivariate.base):
Fitted synthesizer.
table_name (str):
Name of the table to sample from.
num_rows (int):
Number of rows to sample.
def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]
parameters = self._extract_parameters(parent_row, child_name, foreign_key)
table_meta = self.metadata.tables[child_name]

Returns:
pandas.DataFrame:
Sampled rows, shape (, num_rows)
"""
num_rows = num_rows or synthesizer._num_rows
if synthesizer._model:
sampled_rows = synthesizer._sample(num_rows)
else:
sampled_rows = pd.DataFrame(index=range(num_rows))

return self._process_samples(table_name, sampled_rows)

def _get_child_synthesizer(self, parent_row, table_name, foreign_key):
parameters = self._extract_parameters(parent_row, table_name, foreign_key)
table_meta = self.metadata.tables[table_name]
synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name])
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
synthesizer._set_parameters(parameters)
synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor

return synthesizer

def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data):
"""Sample child rows that reference the given parent row.
The sampled rows will be stored in ``sampled_data`` under the ``table_name`` key.
Args:
table_name (str):
The name of the table to sample.
parent_name (str):
The name of the parent table.
parent_row (pandas.Series):
The parent row the child rows should reference.
sampled_data (dict):
A map of table name to sampled table data (pandas.DataFrame).
"""
foreign_key = self._get_foreign_keys(parent_name, table_name)[0]
synthesizer = self._get_child_synthesizer(parent_row, table_name, foreign_key)
table_rows = self._sample_rows(synthesizer, table_name)

if len(table_rows):
parent_key = self.metadata.tables[parent_name].primary_key
table_rows[foreign_key] = parent_row[parent_key]

previous = sampled_data.get(table_name)
if previous is None:
sampled_data[table_name] = table_rows
else:
sampled_data[table_name] = pd.concat(
[previous, table_rows]).reset_index(drop=True)

def _sample_children(self, table_name, sampled_data, table_rows):
"""Recursively sample the child tables of the given table.
Sampled child data will be stored into `sampled_data`.
Args:
table_name (str):
The name of the table whose children will be sampled.
sampled_data (dict):
A map of table name to the sampled table data (pandas.DataFrame).
table_rows (pandas.DataFrame):
The sampled rows of the given table.
"""
for child_name in self.metadata._get_child_map()[table_name]:
if child_name not in sampled_data:
LOGGER.info('Sampling rows from child table %s', child_name)
for _, row in table_rows.iterrows():
self._sample_child_rows(child_name, table_name, row, sampled_data)

child_rows = sampled_data[child_name]
self._sample_children(child_name, sampled_data, child_rows)

@staticmethod
def _find_parent_id(likelihoods, num_rows):
"""Find the parent id for one row based on the likelihoods of parent id values.
Expand Down Expand Up @@ -484,89 +354,44 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):

return pd.DataFrame(likelihoods, index=table_rows.index)

def _find_parent_ids(self, table_name, parent_name, foreign_key, sampled_data):
def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key):
"""Find parent ids for the given table and foreign key.
The parent ids are chosen randomly based on the likelihood of the available
parent ids in the parent table. If the parent table is not sampled, this method
will first sample rows for the parent table.
Args:
table_name (str):
The name of the table to find parent ids for.
parent_name (str):
The name of the parent table.
child_table (pd.DataFrame):
parent_table (pd.DataFrame):
child_name (str):
parent_name (dict):
Map of table name to sampled data (pandas.DataFrame).
foreign_key (str):
The name of the foreign key column in the child table.
sampled_data (dict):
Map of table name to sampled data (pandas.DataFrame).
Returns:
pandas.Series:
The parent ids for the given table data.
"""
table_rows = sampled_data[table_name]
if parent_name in sampled_data:
parent_rows = sampled_data[parent_name]
else:
ratio = self._table_sizes[parent_name] / self._table_sizes[table_name]
num_parent_rows = max(int(round(len(table_rows) * ratio)), 1)
parent_model = self._table_synthesizers[parent_name]
parent_rows = self._sample_rows(parent_model, parent_name, num_parent_rows)

primary_key = self.metadata.tables[parent_name].primary_key
parent_rows = parent_rows.set_index(primary_key)
num_rows = parent_rows[f'__{table_name}__{foreign_key}__num_rows'].fillna(0).clip(0)
parent_table = parent_table.set_index(primary_key)
num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0).clip(0)

likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name, foreign_key)
likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)

def _sample_table(self, table_name, scale=1.0, sample_children=True, sampled_data=None):
"""Sample a single table and optionally its children."""
if sampled_data is None:
sampled_data = {}

num_rows = int(self._table_sizes[table_name] * scale)

LOGGER.info('Sampling %s rows from table %s', num_rows, table_name)

synthesizer = self._table_synthesizers[table_name]
table_rows = self._sample_rows(synthesizer, table_name, num_rows)
sampled_data[table_name] = table_rows

if sample_children:
self._sample_children(table_name, sampled_data, table_rows)

return sampled_data

def _sample(self, scale=1.0):
"""Sample the entire dataset.
Returns a dictionary with all the tables of the dataset. The amount of rows sampled will
depend from table to table. This is because the children tables are created modelling the
relation that they have with their parent tables, so its behavior may change from one
table to another.
Args:
scale (float):
A float representing how much to scale the data by. If scale is set to ``1.0``,
this does not scale the sizes of the tables. If ``scale`` is greater than ``1.0``
create more rows than the original data by a factor of ``scale``.
If ``scale`` is lower than ``1.0`` create fewer rows by the factor of ``scale``
than the original tables. Defaults to ``1.0``.
Returns:
dict:
A dictionary containing as keys the names of the tables and as values the
sampled data tables as ``pandas.DataFrame``.
Raises:
NotFittedError:
A ``NotFittedError`` is raised when the ``SDV`` instance has not been fitted yet.
"""
sampled_data = {}
for table in self.metadata.tables:
if not self.metadata._get_parent_map().get(table):
self._sample_table(table, scale=scale, sampled_data=sampled_data)

return self._finalize(sampled_data)
def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
if foreign_key not in child_table:
parent_ids = self._find_parent_ids(
child_table=child_table,
parent_table=parent_table,
child_name=child_name,
parent_name=parent_name,
foreign_key=foreign_key
)
child_table[foreign_key] = parent_ids.to_numpy()
4 changes: 4 additions & 0 deletions sdv/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""SDV Sampling module."""

from sdv.sampling.hierarchical_sampler import BaseHierarchicalSampler
from sdv.sampling.independent_sampler import BaseIndependentSampler
from sdv.sampling.tabular import Condition

__all__ = [
'BaseHierarchicalSampler',
'BaseIndependentSampler',
'Condition',
]
2 changes: 1 addition & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _sample_rows(self, synthesizer, num_rows=None):
Sampled rows, shape (, num_rows)
"""
num_rows = num_rows or synthesizer._num_rows
return synthesizer._sample_batch(num_rows, remove_extra_columns=False)
return synthesizer._sample_batch(int(num_rows), keep_extra_columns=True)

def _get_num_rows_from_parent(self, parent_row, child_name, foreign_key):
"""Get the number of rows to sample for the child from the parent row."""
Expand Down
11 changes: 6 additions & 5 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
Maximum tolerance when considering a float match.
previous_rows (pandas.DataFrame):
Valid rows sampled in the previous iterations.
remove_extra_columns (bool):
keep_extra_columns (bool):
Whether to keep extra columns from the sampled data. Defaults to False.
Returns:
Expand All @@ -604,10 +604,11 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
* int:
Number of rows that are considered valid.
"""
if not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)
if (self._model and (
self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns)):
if not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)

if self._data_processor.get_sdtypes(primary_keys=False):
if conditions is None:
raw_sampled = self._sample(num_rows)
else:
Expand Down Expand Up @@ -682,7 +683,7 @@ def _sample_batch(self, batch_size, max_tries=100,
output_file_path (str or None):
The file to periodically write sampled rows to. If None, does not write
rows anywhere.
remove_extra_columns (bool):
keep_extra_columns (bool):
Whether to keep extra columns from the sampled data. Defaults to False.
Returns:
Expand Down
Loading

0 comments on commit 6661b9a

Please sign in to comment.