Support null foreign keys in HMA Synthesizer #2124

Merged · 17 commits · Aug 15, 2024
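
In short: the HMA synthesizer now learns which child rows have null foreign keys, models them with a dedicated synthesizer, and reproduces the null share at sampling time. A minimal usage sketch of what this enables (assuming SDV's public `MultiTableMetadata`/`HMASynthesizer` API, mirroring the integration test in this PR):

```python
# Minimal usage sketch (assumed API usage, not part of the diff): fit and
# sample multi-table data whose foreign-key column contains NaNs.
import numpy as np
import pandas as pd
from sdv.metadata import MultiTableMetadata
from sdv.multi_table import HMASynthesizer

metadata = MultiTableMetadata()
metadata.add_table('parent')
metadata.add_column('parent', 'id', sdtype='id')
metadata.set_primary_key('parent', 'id')
metadata.add_table('child')
metadata.add_column('child', 'id', sdtype='id')
metadata.set_primary_key('child', 'id')
metadata.add_column('child', 'fk', sdtype='id')
metadata.add_relationship(
    parent_table_name='parent',
    child_table_name='child',
    parent_primary_key='id',
    child_foreign_key='fk',
)

data = {
    'parent': pd.DataFrame({'id': [1, 2, 3]}),
    # The last child row is an orphan: its foreign key is NaN.
    'child': pd.DataFrame({'id': [1, 2, 3], 'fk': [1.0, 2.0, np.nan]}),
}

synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)  # previously crashed on the NaN foreign key
sampled = synthesizer.sample()  # sampled['child']['fk'] keeps a NaN share
```
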
86 changes: 63 additions & 23 deletions sdv/multi_table/hma.py
@@ -158,6 +158,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
self._table_sizes = {}
self._max_child_rows = {}
self._min_child_rows = {}
self._null_child_synthesizers = {}
self._augmented_tables = []
self._learned_relationships = 0
self._default_parameters = {}
@@ -310,10 +311,17 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
pbar_args = self._get_pbar_args(desc=progress_bar_desc)

for foreign_key_value in tqdm(foreign_key_values, **pbar_args):
child_rows = child_table.loc[[foreign_key_value]]
try:
child_rows = child_table.loc[[foreign_key_value]]
except KeyError:
# Before pandas 2.1, df.loc[[np.nan]] raises a KeyError
if pd.isna(foreign_key_value):
child_rows = child_table[child_table.index.isna()]
else:
raise
child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)]
try:
if child_rows.empty:
if child_rows.empty and not pd.isna(foreign_key_value):
row = pd.Series({'num_rows': len(child_rows)})
row.index = f'__{child_name}__{foreign_key}__' + row.index
else:
@@ -324,28 +332,38 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
self._set_extended_columns_distributions(
synthesizer, child_name, child_rows.columns
)
synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
row = synthesizer._get_parameters()
row = pd.Series(row)
row.index = f'__{child_name}__{foreign_key}__' + row.index

if scale_columns is None:
scale_columns = [column for column in row.index if column.endswith('scale')]

if len(child_rows) == 1:
row.loc[scale_columns] = None

extension_rows.append(row)
index.append(foreign_key_value)
if not child_rows.empty:
synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
row = synthesizer._get_parameters()
row = pd.Series(row)
row.index = f'__{child_name}__{foreign_key}__' + row.index

if not pd.isna(foreign_key_value):
if scale_columns is None:
scale_columns = [
column for column in row.index if column.endswith('scale')
]

if len(child_rows) == 1:
row.loc[scale_columns] = None

if pd.isna(foreign_key_value):
self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer
else:
extension_rows.append(row)
index.append(foreign_key_value)
except Exception:
# Skip child row subsets that fail
pass

return pd.DataFrame(extension_rows, index=index)
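
Aside on the `try/except KeyError` added above: it guards a pandas version difference, not SDV logic. A standalone sketch of the quirk and the fallback, assuming a float index that contains NaN:

```python
# Sketch of the pandas quirk handled above: looking up a NaN label with
# .loc[[np.nan]] raises KeyError on pandas < 2.1, so the code falls back
# to an explicit isna() mask over the index.
import numpy as np
import pandas as pd

child_table = pd.DataFrame({'value': [10, 20, 30]}, index=[1.0, 2.0, np.nan])

def rows_for_key(table, key):
    try:
        return table.loc[[key]]
    except KeyError:
        if pd.isna(key):  # pre-2.1 pandas cannot look up a NaN label
            return table[table.index.isna()]
        raise

print(rows_for_key(child_table, np.nan))  # the single NaN-keyed row
```
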

@staticmethod
def _clear_nans(table_data):
for column in table_data.columns:
def _clear_nans(table_data, ignore_cols=None):
columns = set(table_data.columns)
if ignore_cols is not None:
columns = columns - set(ignore_cols)
for column in columns:
column_data = table_data[column]
if column_data.dtype in (int, float):
fill_value = 0 if column_data.isna().all() else column_data.mean()
@@ -405,14 +423,19 @@ def _augment_table(self, table, tables, table_name):
table[num_rows_key] = table[num_rows_key].fillna(0)
self._max_child_rows[num_rows_key] = table[num_rows_key].max()
self._min_child_rows[num_rows_key] = table[num_rows_key].min()
self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - (
table[num_rows_key].sum() / child_table.shape[0]
)

if len(extension.columns) > 0:
self._parent_extended_columns[table_name].extend(list(extension.columns))

tables[table_name] = table
self._learned_relationships += 1
self._augmented_tables.append(table_name)
self._clear_nans(table)

foreign_keys = self.metadata._get_all_foreign_keys(table_name)
self._clear_nans(table, ignore_cols=foreign_keys)

return table

Expand Down Expand Up @@ -525,12 +548,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key):
def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
# A child table is created based on only one foreign key.
foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]
parameters = self._extract_parameters(parent_row, child_name, foreign_key)
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})

table_meta = self.metadata.tables[child_name]
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
synthesizer._set_parameters(parameters, default_parameters)
if parent_row is not None:
parameters = self._extract_parameters(parent_row, child_name, foreign_key)
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})

table_meta = self.metadata.tables[child_name]
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
synthesizer._set_parameters(parameters, default_parameters)
else:
synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}']

synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor

return synthesizer
Expand Down Expand Up @@ -580,6 +608,9 @@ def _find_parent_id(likelihoods, num_rows):
candidates.append(parent)
candidate_weights.append(weight)

# Cast candidates to a Series so np.random.choice uses the desired dtype
candidates = pd.Series(candidates, dtype=likelihoods.index.dtype)

# All available candidates were assigned 0 likelihood of being the parent id
if sum(candidate_weights) == 0:
chosen_parent = np.random.choice(candidates)
@@ -629,6 +660,14 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
except (AttributeError, np.linalg.LinAlgError):
likelihoods[parent_id] = None

null_child_synths = getattr(self, '_null_child_synthesizers', {})
if f'__{table_name}__{foreign_key}' in null_child_synths:
try:
likelihoods[np.nan] = synthesizer._get_likelihood(table_rows)

except (AttributeError, np.linalg.LinAlgError):
likelihoods[np.nan] = None

return pd.DataFrame(likelihoods, index=table_rows.index)

def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key):
@@ -657,6 +696,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key):
primary_key = self.metadata.tables[parent_name].primary_key
parent_table = parent_table.set_index(primary_key)
num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy()
num_rows.loc[np.nan] = child_table.shape[0] - num_rows.sum()

likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)
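
With the `num_rows.loc[np.nan]` budget and the NaN entry added to the likelihood table, "no parent" simply competes as one more candidate during parent matching. A toy sketch of that idea with assumed numbers (plain dicts rather than SDV internals):

```python
# Toy illustration (assumed numbers): np.nan joins the candidate pool with a
# budget equal to the child rows left over once every real parent has
# claimed its num_rows.
import numpy as np

num_rows = {'p1': 2, 'p2': 1}  # per-parent child-row budgets
total_child_rows = 5
num_rows[np.nan] = total_child_rows - sum(num_rows.values())  # 2 orphan rows

likelihoods = {'p1': 0.7, 'p2': 0.2, np.nan: 0.1}

def pick_parent(likelihoods, num_rows):
    # Only candidates with remaining budget stay in the pool.
    candidates = [c for c in likelihoods if num_rows[c] > 0]
    weights = np.array([likelihoods[c] for c in candidates])
    if weights.sum() == 0:
        idx = np.random.randint(len(candidates))  # fall back to uniform choice
    else:
        idx = np.random.choice(len(candidates), p=weights / weights.sum())
    chosen = candidates[idx]
    num_rows[chosen] -= 1
    return chosen

# Each child row draws a parent; roughly two draws end up as np.nan,
# i.e. children that will be sampled with a null foreign key.
assignments = [pick_parent(likelihoods, num_rows) for _ in range(total_child_rows)]
```
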
7 changes: 4 additions & 3 deletions sdv/multi_table/utils.py
@@ -461,9 +461,10 @@ def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep, drop_mis
def _subsample_table_and_descendants(data, metadata, table, num_rows, drop_missing_values):
"""Subsample the table and its descendants.

The logic is to first subsample all the NaN foreign keys of the table when ``drop_missing_values``
is True. We raise an error if we cannot reach referential integrity while keeping
the number of rows. Then, we drop rows of the descendants to ensure referential integrity.
The logic is to first subsample all the NaN foreign keys of the table when
``drop_missing_values`` is True. We raise an error if we cannot reach referential integrity
while keeping the number of rows. Then, we drop rows of the descendants to ensure referential
integrity.

Args:
data (dict):
33 changes: 27 additions & 6 deletions sdv/sampling/hierarchical_sampler.py
@@ -3,6 +3,7 @@
import logging
import warnings

import numpy as np
import pandas as pd

LOGGER = logging.getLogger(__name__)
@@ -24,6 +25,7 @@ class BaseHierarchicalSampler:

def __init__(self, metadata, table_synthesizers, table_sizes):
self.metadata = metadata
self._null_foreign_key_percentages = {}
self._table_synthesizers = table_synthesizers
self._table_sizes = table_sizes

@@ -103,7 +105,9 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num
row_indices = sampled_rows.index
sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key]
else:
sampled_rows[foreign_key] = parent_row[parent_key]
sampled_rows[foreign_key] = (
parent_row[parent_key] if parent_row is not None else np.nan
)

previous = sampled_data.get(child_name)
if previous is None:
@@ -143,16 +147,19 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
"""
total_num_rows = round(self._table_sizes[child_name] * scale)
for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
null_fk_pctgs = getattr(self, '_null_foreign_key_percentages', {})
null_fk_pctg = null_fk_pctgs.get(f'__{child_name}__{foreign_key}', 0)
total_parent_rows = round(total_num_rows * (1 - null_fk_pctg))
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key]
max_rows = self._max_child_rows[num_rows_key]
key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int)

while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
while sum(sampled_data[table_name][num_rows_key]) != total_parent_rows:
num_rows_column = sampled_data[table_name][num_rows_key].argsort()

if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
if sum(sampled_data[table_name][num_rows_key]) < total_parent_rows:
for i in num_rows_column:
# If the number of rows is already at the maximum, skip
# The exception is when the smallest value is already at the maximum,
@@ -164,7 +171,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
break

sampled_data[table_name].loc[i, num_rows_key] += 1
if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows:
break

else:
@@ -179,7 +186,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
break

sampled_data[table_name].loc[i, num_rows_key] -= 1
if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows:
break

def _sample_children(self, table_name, sampled_data, scale=1.0):
@@ -207,8 +214,9 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
sampled_data=sampled_data,
)

foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0]

if child_name not in sampled_data: # No child rows sampled, force row creation
foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0]
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
max_num_child_index = sampled_data[table_name][num_rows_key].idxmax()
parent_row = sampled_data[table_name].iloc[max_num_child_index]
@@ -221,6 +229,19 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
num_rows=1,
)

total_num_rows = round(self._table_sizes[child_name] * scale)
null_fk_pctgs = getattr(self, '_null_foreign_key_percentages', {})
null_fk_pctg = null_fk_pctgs.get(f'__{child_name}__{foreign_key}', 0)
num_null_rows = round(total_num_rows * null_fk_pctg)
if num_null_rows > 0:
self._add_child_rows(
child_name=child_name,
parent_name=table_name,
parent_row=None,
sampled_data=sampled_data,
num_rows=num_null_rows,
)

self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale)

def _finalize(self, sampled_data):
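
The sampler-side bookkeeping reduces to arithmetic: the null-foreign-key percentage recorded at fit time (`1 - num_rows.sum() / child_rows` in `_augment_table`) splits each child's row budget between parented rows (`_enforce_table_size`) and orphan rows (`_sample_children`). A worked sketch with assumed numbers:

```python
# Worked example (assumed numbers) of the budget split above. If 20% of the
# child's foreign keys were null at fit time, sampling at scale=1.0 reserves
# that share of the child's rows for parentless sampling.
table_size = 100     # child rows seen during fit (self._table_sizes)
scale = 1.0
null_fk_pctg = 0.2   # stored in _null_foreign_key_percentages at fit time

total_num_rows = round(table_size * scale)                       # 100
total_parent_rows = round(total_num_rows * (1 - null_fk_pctg))   # 80 parented rows
num_null_rows = round(total_num_rows * null_fk_pctg)             # 20 rows with fk = NaN

assert total_parent_rows + num_null_rows == total_num_rows
```
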
52 changes: 40 additions & 12 deletions tests/integration/multi_table/test_hma.py
@@ -1318,49 +1318,68 @@ def test_null_foreign_keys(self):
"""Test that the synthesizer does not crash when there are null foreign keys."""
# Setup
metadata = MultiTableMetadata()
metadata.add_table('parent_table')
metadata.add_column('parent_table', 'id', sdtype='id')
metadata.set_primary_key('parent_table', 'id')
metadata.add_table('parent_table1')
metadata.add_column('parent_table1', 'id', sdtype='id')
metadata.set_primary_key('parent_table1', 'id')

metadata.add_table('parent_table2')
metadata.add_column('parent_table2', 'id', sdtype='id')
metadata.set_primary_key('parent_table2', 'id')

metadata.add_table('child_table1')
metadata.add_column('child_table1', 'id', sdtype='id')
metadata.set_primary_key('child_table1', 'id')
metadata.add_column('child_table1', 'fk', sdtype='id')
metadata.add_column('child_table1', 'fk1', sdtype='id')
metadata.add_column('child_table1', 'fk2', sdtype='id')

metadata.add_table('child_table2')
metadata.add_column('child_table2', 'id', sdtype='id')
metadata.set_primary_key('child_table2', 'id')
metadata.add_column('child_table2', 'fk1', sdtype='id')
metadata.add_column('child_table2', 'fk2', sdtype='id')
metadata.add_column('child_table2', 'cat_type', sdtype='categorical')

metadata.add_relationship(
parent_table_name='parent_table',
parent_table_name='parent_table1',
child_table_name='child_table1',
parent_primary_key='id',
child_foreign_key='fk',
child_foreign_key='fk1',
)

metadata.add_relationship(
parent_table_name='parent_table',
parent_table_name='parent_table2',
child_table_name='child_table1',
parent_primary_key='id',
child_foreign_key='fk2',
)

metadata.add_relationship(
parent_table_name='parent_table1',
child_table_name='child_table2',
parent_primary_key='id',
child_foreign_key='fk1',
)

metadata.add_relationship(
parent_table_name='parent_table',
parent_table_name='parent_table1',
child_table_name='child_table2',
parent_primary_key='id',
child_foreign_key='fk2',
)

data = {
'parent_table': pd.DataFrame({'id': [1, 2, 3]}),
'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}),
'parent_table1': pd.DataFrame({'id': [1, 2, 3]}),
'parent_table2': pd.DataFrame({'id': ['alpha', 'beta', 'gamma']}),
'child_table1': pd.DataFrame({
'id': [1, 2, 3],
'fk1': pd.Series([np.nan, 2, np.nan], dtype='float64'),
'fk2': pd.Series(['alpha', 'beta', np.nan], dtype='object'),
}),
'child_table2': pd.DataFrame({
'id': [1, 2, 3],
'fk1': [1, 2, np.nan],
'fk2': [1, 2, np.nan],
'fk2': pd.Series([1, np.nan, np.nan], dtype='float64'),
'cat_type': pd.Series(['siamese', 'persian', 'american shorthair'], dtype='object'),
}),
}

Expand All @@ -1370,8 +1389,17 @@ def test_null_foreign_keys(self):
metadata.validate()
metadata.validate_data(data)

# Run and Assert
# Run
synthesizer.fit(data)
sampled = synthesizer.sample()

# Assert
assert len(sampled['parent_table1']) == 3
assert len(sampled['parent_table2']) == 3
assert sum(pd.isna(sampled['child_table1']['fk1'])) == 2
assert sum(pd.isna(sampled['child_table1']['fk2'])) == 1
assert sum(pd.isna(sampled['child_table2']['fk1'])) == 1
assert sum(pd.isna(sampled['child_table2']['fk2'])) == 2
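
These counts are not arbitrary: sampling preserves the null-foreign-key share observed at fit time, so the two-out-of-three NaN share in the training `fk1` column of `child_table1` reappears in the sample. A quick check with the test's own numbers:

```python
# Sketch with the test's numbers: two of the three training values in
# child_table1['fk1'] are NaN, and sampling preserves that share.
import numpy as np
import pandas as pd

fk1 = pd.Series([np.nan, 2, np.nan])  # training child_table1['fk1']
null_pctg = fk1.isna().mean()         # 2/3
sampled_rows = 3                      # sample() keeps the original sizes
assert round(sampled_rows * null_pctg) == 2  # matches the fk1 assertion above
```
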

def test_sampling_with_unknown_sdtype_numerical_column(self):
"""Test that if a numerical column is detected as unknown in the metadata,
2 changes: 1 addition & 1 deletion tests/integration/utils/test_poc.py
@@ -242,4 +242,4 @@ def test_get_random_subset_with_missing_values(metadata, data):

# Assert
assert len(result['child']) == 3
assert result['child']['parent_id'].isnull().sum() > 0
assert result['child']['parent_id'].isna().sum() > 0