Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use HierarchicalSampler mixin with HMA #1449

Merged
merged 2 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 41 additions & 216 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from tqdm import tqdm

from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.sampling import BaseHierarchicalSampler

LOGGER = logging.getLogger(__name__)


class HMASynthesizer(BaseMultiTableSynthesizer):
class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
amontanez24 marked this conversation as resolved.
Show resolved Hide resolved
"""Hierarchical Modeling Algorithm One.
Args:
Expand All @@ -31,13 +32,19 @@ class HMASynthesizer(BaseMultiTableSynthesizer):
}

def __init__(self, metadata, locales=None, verbose=True):
super().__init__(metadata, locales=locales)
BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales)
self._table_sizes = {}
self._max_child_rows = {}
self._augmented_tables = []
self._learned_relationships = 0
self.verbose = verbose

BaseHierarchicalSampler.__init__(
self,
self.metadata,
self._table_synthesizers,
self._table_sizes)

def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
"""Generate the extension columns for this child table.
Expand Down Expand Up @@ -115,15 +122,6 @@ def _clear_nans(table_data):

table_data[column] = table_data[column].fillna(fill_value)

def _get_foreign_keys(self, table_name, child_name):
foreign_keys = []
for relation in self.metadata.relationships:
if table_name == relation['parent_table_name'] and\
child_name == relation['child_table_name']:
foreign_keys.append(deepcopy(relation['child_foreign_key']))

return foreign_keys

def _augment_table(self, table, tables, table_name):
"""Generate the extension columns for this table.
Expand Down Expand Up @@ -152,7 +150,7 @@ def _augment_table(self, table, tables, table_name):
else:
child_table = tables[child_name]

foreign_keys = self._get_foreign_keys(table_name, child_name)
foreign_keys = self.metadata._get_foreign_keys(table_name, child_name)
for foreign_key in foreign_keys:
progress_bar_desc = (
f'({self._learned_relationships + 1}/{len(self.metadata.relationships)}) '
Expand Down Expand Up @@ -241,46 +239,6 @@ def _augment_tables(self, processed_data):
LOGGER.info('Augmentation Complete')
return augmented_data

def _finalize(self, sampled_data):
"""Do the final touches to the generated data.
This method reverts the previous transformations to go back
to values in the original space and also adds the parent
keys in case foreign key relationships exist between the tables.
Args:
sampled_data (dict):
Generated data
Return:
pandas.DataFrame:
Formatted synthesized data.
"""
final_data = {}
for table_name, table_rows in sampled_data.items():
parents = self.metadata._get_parent_map().get(table_name)
if parents:
for parent_name in parents:
foreign_keys = self._get_foreign_keys(parent_name, table_name)
for foreign_key in foreign_keys:
if foreign_key not in table_rows:
parent_ids = self._find_parent_ids(
table_name,
parent_name,
foreign_key,
sampled_data
)
table_rows[foreign_key] = parent_ids.to_numpy()

synthesizer = self._table_synthesizers.get(table_name)
dtypes = synthesizer._data_processor._dtypes
for name, dtype in dtypes.items():
table_rows[name] = table_rows[name].dropna().astype(dtype)

final_data[table_name] = table_rows[list(dtypes.keys())]

return final_data

def _extract_parameters(self, parent_row, table_name, foreign_key):
"""Get the params from a generated parent row.
Expand Down Expand Up @@ -308,105 +266,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key):

return flat_parameters.rename(new_keys).to_dict()

def _process_samples(self, table_name, sampled_rows):
"""Process the ``sampled_rows`` for the given ``table_name``.
Process the raw samples and convert them to the original space by reverse transforming
them. Also, when there are synthesizer columns (columns used to recreate an instance
of a synthesizer), those will be returned together.
"""
data_processor = self._table_synthesizers[table_name]._data_processor
sampled = data_processor.reverse_transform(sampled_rows)

synthesizer_columns = list(set(sampled_rows.columns) - set(sampled.columns))
if synthesizer_columns:
sampled = pd.concat([sampled, sampled_rows[synthesizer_columns]], axis=1)

return sampled
def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]
parameters = self._extract_parameters(parent_row, child_name, foreign_key)
table_meta = self.metadata.tables[child_name]

def _sample_rows(self, synthesizer, table_name, num_rows=None):
"""Sample ``num_rows`` from ``synthesizer``.
Args:
synthesizer (copula.multivariate.base):
Fitted synthesizer.
table_name (str):
Name of the table to sample from.
num_rows (int):
Number of rows to sample.
Returns:
pandas.DataFrame:
Sampled rows, shape (, num_rows)
"""
num_rows = num_rows or synthesizer._num_rows
if synthesizer._model:
sampled_rows = synthesizer._sample(num_rows)
else:
sampled_rows = pd.DataFrame(index=range(num_rows))

return self._process_samples(table_name, sampled_rows)

def _get_child_synthesizer(self, parent_row, table_name, foreign_key):
parameters = self._extract_parameters(parent_row, table_name, foreign_key)
table_meta = self.metadata.tables[table_name]
synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name])
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
synthesizer._set_parameters(parameters)
synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor

return synthesizer

def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data):
"""Sample child rows that reference the given parent row.
The sampled rows will be stored in ``sampled_data`` under the ``table_name`` key.
Args:
table_name (str):
The name of the table to sample.
parent_name (str):
The name of the parent table.
parent_row (pandas.Series):
The parent row the child rows should reference.
sampled_data (dict):
A map of table name to sampled table data (pandas.DataFrame).
"""
foreign_key = self._get_foreign_keys(parent_name, table_name)[0]
synthesizer = self._get_child_synthesizer(parent_row, table_name, foreign_key)
table_rows = self._sample_rows(synthesizer, table_name)

if len(table_rows):
parent_key = self.metadata.tables[parent_name].primary_key
table_rows[foreign_key] = parent_row[parent_key]

previous = sampled_data.get(table_name)
if previous is None:
sampled_data[table_name] = table_rows
else:
sampled_data[table_name] = pd.concat(
[previous, table_rows]).reset_index(drop=True)

def _sample_children(self, table_name, sampled_data, table_rows):
"""Recursively sample the child tables of the given table.
Sampled child data will be stored into `sampled_data`.
Args:
table_name (str):
The name of the table whose children will be sampled.
sampled_data (dict):
A map of table name to the sampled table data (pandas.DataFrame).
table_rows (pandas.DataFrame):
The sampled rows of the given table.
"""
for child_name in self.metadata._get_child_map()[table_name]:
if child_name not in sampled_data:
LOGGER.info('Sampling rows from child table %s', child_name)
for _, row in table_rows.iterrows():
self._sample_child_rows(child_name, table_name, row, sampled_data)

child_rows = sampled_data[child_name]
self._sample_children(child_name, sampled_data, child_rows)

@staticmethod
def _find_parent_id(likelihoods, num_rows):
"""Find the parent id for one row based on the likelihoods of parent id values.
Expand Down Expand Up @@ -484,89 +354,44 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):

return pd.DataFrame(likelihoods, index=table_rows.index)

def _find_parent_ids(self, table_name, parent_name, foreign_key, sampled_data):
def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key):
"""Find parent ids for the given table and foreign key.
The parent ids are chosen randomly based on the likelihood of the available
parent ids in the parent table. If the parent table is not sampled, this method
will first sample rows for the parent table.
parent ids in the parent table.
Args:
table_name (str):
The name of the table to find parent ids for.
parent_name (str):
The name of the parent table.
child_table (pd.DataFrame):
frances-h marked this conversation as resolved.
Show resolved Hide resolved
The child table dataframe.
parent_table (pd.DataFrame):
The parent table dataframe.
child_name (str):
The name of the child table.
parent_name (dict):
Map of table name to sampled data (pandas.DataFrame).
foreign_key (str):
The name of the foreign key column in the child table.
sampled_data (dict):
Map of table name to sampled data (pandas.DataFrame).
Returns:
pandas.Series:
The parent ids for the given table data.
"""
table_rows = sampled_data[table_name]
if parent_name in sampled_data:
parent_rows = sampled_data[parent_name]
else:
ratio = self._table_sizes[parent_name] / self._table_sizes[table_name]
num_parent_rows = max(int(round(len(table_rows) * ratio)), 1)
parent_model = self._table_synthesizers[parent_name]
parent_rows = self._sample_rows(parent_model, parent_name, num_parent_rows)

# Create a copy of the parent table with the primary key as index to calculate likilihoods
primary_key = self.metadata.tables[parent_name].primary_key
parent_rows = parent_rows.set_index(primary_key)
num_rows = parent_rows[f'__{table_name}__{foreign_key}__num_rows'].fillna(0).clip(0)
parent_table = parent_table.set_index(primary_key)
frances-h marked this conversation as resolved.
Show resolved Hide resolved
num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0).clip(0)

likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name, foreign_key)
likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)

def _sample_table(self, table_name, scale=1.0, sample_children=True, sampled_data=None):
"""Sample a single table and optionally its children."""
if sampled_data is None:
sampled_data = {}

num_rows = int(self._table_sizes[table_name] * scale)

LOGGER.info('Sampling %s rows from table %s', num_rows, table_name)

synthesizer = self._table_synthesizers[table_name]
table_rows = self._sample_rows(synthesizer, table_name, num_rows)
sampled_data[table_name] = table_rows

if sample_children:
self._sample_children(table_name, sampled_data, table_rows)

return sampled_data

def _sample(self, scale=1.0):
"""Sample the entire dataset.
Returns a dictionary with all the tables of the dataset. The amount of rows sampled will
depend from table to table. This is because the children tables are created modelling the
relation that they have with their parent tables, so its behavior may change from one
table to another.
Args:
scale (float):
A float representing how much to scale the data by. If scale is set to ``1.0``,
this does not scale the sizes of the tables. If ``scale`` is greater than ``1.0``
create more rows than the original data by a factor of ``scale``.
If ``scale`` is lower than ``1.0`` create fewer rows by the factor of ``scale``
than the original tables. Defaults to ``1.0``.
Returns:
dict:
A dictionary containing as keys the names of the tables and as values the
sampled data tables as ``pandas.DataFrame``.
Raises:
NotFittedError:
A ``NotFittedError`` is raised when the ``SDV`` instance has not been fitted yet.
"""
sampled_data = {}
for table in self.metadata.tables:
if not self.metadata._get_parent_map().get(table):
self._sample_table(table, scale=scale, sampled_data=sampled_data)

return self._finalize(sampled_data)
def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
if foreign_key not in child_table:
parent_ids = self._find_parent_ids(
child_table=child_table,
parent_table=parent_table,
child_name=child_name,
parent_name=parent_name,
foreign_key=foreign_key
)
child_table[foreign_key] = parent_ids.to_numpy()
4 changes: 4 additions & 0 deletions sdv/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""SDV Sampling module."""

from sdv.sampling.hierarchical_sampler import BaseHierarchicalSampler
from sdv.sampling.independent_sampler import BaseIndependentSampler
from sdv.sampling.tabular import Condition

__all__ = [
'BaseHierarchicalSampler',
'BaseIndependentSampler',
'Condition',
]
2 changes: 1 addition & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _sample_rows(self, synthesizer, num_rows=None):
Sampled rows, shape (, num_rows)
"""
num_rows = num_rows or synthesizer._num_rows
return synthesizer._sample_batch(num_rows, remove_extra_columns=False)
return synthesizer._sample_batch(int(num_rows), keep_extra_columns=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amontanez24 @pvk-developer This line is an example where using an assert could be useful. num_rows should always be an integer, and if it is not then wherever it was set did so incorrectly, so the code there should be updated to cast it to an int.

So at this point we should assert isinstance(num_rows, int), or change the docstring to say that num_rows can be a float as well. You shouldn't have to write code assuming other methods/inputs are invalid, ie casting to int in this case.


def _get_num_rows_from_parent(self, parent_row, child_name, foreign_key):
"""Get the number of rows to sample for the child from the parent row."""
Expand Down
10 changes: 6 additions & 4 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
Maximum tolerance when considering a float match.
previous_rows (pandas.DataFrame):
Valid rows sampled in the previous iterations.
remove_extra_columns (bool):
keep_extra_columns (bool):
Whether to keep extra columns from the sampled data. Defaults to False.
Returns:
Expand All @@ -604,10 +604,12 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
* int:
Number of rows that are considered valid.
"""
if not self._random_state_set:
if self._model and not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)

if self._data_processor.get_sdtypes(primary_keys=False):
need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns
if self._model and need_sample:

if conditions is None:
raw_sampled = self._sample(num_rows)
else:
Expand Down Expand Up @@ -682,7 +684,7 @@ def _sample_batch(self, batch_size, max_tries=100,
output_file_path (str or None):
The file to periodically write sampled rows to. If None, does not write
rows anywhere.
remove_extra_columns (bool):
keep_extra_columns (bool):
Whether to keep extra columns from the sampled data. Defaults to False.
Returns:
Expand Down
Loading