From 04422c0aa6f9bede1d6b6bde617afb620aebca4d Mon Sep 17 00:00:00 2001 From: rwedge Date: Wed, 26 Jun 2024 14:42:47 -0400 Subject: [PATCH 01/17] fit --- sdv/multi_table/hma.py | 12 +++++++++--- tests/integration/multi_table/test_hma.py | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 5df4dc358..68da15933 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -158,6 +158,8 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._table_sizes = {} self._max_child_rows = {} self._min_child_rows = {} + self._null_child_synthesizers = {} + self._null_foreign_key_percentages = {} self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} @@ -335,8 +337,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if len(child_rows) == 1: row.loc[scale_columns] = None - extension_rows.append(row) - index.append(foreign_key_value) + if pd.isna(foreign_key_value): + self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer + else: + extension_rows.append(row) + index.append(foreign_key_value) except Exception: # Skip children rows subsets that fail pass @@ -405,6 +410,7 @@ def _augment_table(self, table, tables, table_name): table[num_rows_key] = table[num_rows_key].fillna(0) self._max_child_rows[num_rows_key] = table[num_rows_key].max() self._min_child_rows[num_rows_key] = table[num_rows_key].min() + self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - (table[num_rows_key].sum() / child_table.shape[0]) if len(extension.columns) > 0: self._parent_extended_columns[table_name].extend(list(extension.columns)) @@ -412,7 +418,7 @@ def _augment_table(self, table, tables, table_name): tables[table_name] = table self._learned_relationships += 1 self._augmented_tables.append(table_name) - self._clear_nans(table) + # self._clear_nans(table) TODO: replace with standardizing nans? 
return table diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 1f26e2c42..ede2ddc0c 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1332,6 +1332,7 @@ def test_null_foreign_keys(self): metadata.set_primary_key('child_table2', 'id') metadata.add_column('child_table2', 'fk1', sdtype='id') metadata.add_column('child_table2', 'fk2', sdtype='id') + metadata.add_column('child_table2', 'cat_type', sdtype='categorical') metadata.add_relationship( parent_table_name='parent_table', @@ -1361,6 +1362,7 @@ def test_null_foreign_keys(self): 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], 'fk2': [1, 2, np.nan], + 'cat_type': ['siamese','persian', 'american shorthair'], }), } @@ -1372,6 +1374,8 @@ def test_null_foreign_keys(self): # Run and Assert synthesizer.fit(data) + breakpoint() + def test_sampling_with_unknown_sdtype_numerical_column(self): """Test that if a numerical column is detected as unknown in the metadata, From 3b73660b4b23c1c407c42779d914b796604f545e Mon Sep 17 00:00:00 2001 From: rwedge Date: Fri, 28 Jun 2024 15:00:20 -0400 Subject: [PATCH 02/17] sample (wip) --- sdv/multi_table/hma.py | 25 +++++++++++++++++------ sdv/sampling/hierarchical_sampler.py | 25 ++++++++++++++++++----- tests/integration/multi_table/test_hma.py | 24 +++++++++++----------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 68da15933..646924f1e 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -159,7 +159,6 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._max_child_rows = {} self._min_child_rows = {} self._null_child_synthesizers = {} - self._null_foreign_key_percentages = {} self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} @@ -337,6 +336,7 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if len(child_rows) 
== 1: row.loc[scale_columns] = None + # TODO: handle null synthesizer when child_rows is empty if pd.isna(foreign_key_value): self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer else: @@ -531,12 +531,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): # A child table is created based on only one foreign key. foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0] - parameters = self._extract_parameters(parent_row, child_name, foreign_key) - default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) - table_meta = self.metadata.tables[child_name] - synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) - synthesizer._set_parameters(parameters, default_parameters) + if parent_row is not None: + parameters = self._extract_parameters(parent_row, child_name, foreign_key) + default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) + + table_meta = self.metadata.tables[child_name] + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + synthesizer._set_parameters(parameters, default_parameters) + else: + synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] + synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor return synthesizer @@ -635,6 +640,13 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): except (AttributeError, np.linalg.LinAlgError): likelihoods[parent_id] = None + if f'__{table_name}__{foreign_key}' in self._null_child_synthesizers: + try: + likelihoods[np.nan] = synthesizer._get_likelihood(table_rows) + + except (AttributeError, np.linalg.LinAlgError): + likelihoods[np.nan] = None + return pd.DataFrame(likelihoods, index=table_rows.index) def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key): @@ 
-663,6 +675,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f primary_key = self.metadata.tables[parent_name].primary_key parent_table = parent_table.set_index(primary_key) num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy() + num_rows.loc[np.nan] = child_table.shape[0] - num_rows.sum() likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key) return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 641b6ccbf..3bff6742f 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -3,6 +3,7 @@ import logging import warnings +import numpy as np import pandas as pd LOGGER = logging.getLogger(__name__) @@ -24,6 +25,7 @@ class BaseHierarchicalSampler: def __init__(self, metadata, table_synthesizers, table_sizes): self.metadata = metadata + self._null_foreign_key_percentages = {} self._table_synthesizers = table_synthesizers self._table_sizes = table_sizes @@ -103,7 +105,7 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num row_indices = sampled_rows.index sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key] else: - sampled_rows[foreign_key] = parent_row[parent_key] + sampled_rows[foreign_key] = parent_row[parent_key] if parent_row is not None else np.nan previous = sampled_data.get(child_name) if previous is None: @@ -143,16 +145,18 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): """ total_num_rows = round(self._table_sizes[child_name] * scale) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): + null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] + total_parent_rows = round(total_num_rows * (1 - null_fk_pctg)) num_rows_key = f'__{child_name}__{foreign_key}__num_rows' min_rows = getattr(self, 
'_min_child_rows', {num_rows_key: 0})[num_rows_key] max_rows = self._max_child_rows[num_rows_key] key_data = sampled_data[table_name][num_rows_key].fillna(0).round() sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int) - while sum(sampled_data[table_name][num_rows_key]) != total_num_rows: + while sum(sampled_data[table_name][num_rows_key]) != total_parent_rows: num_rows_column = sampled_data[table_name][num_rows_key].argsort() - if sum(sampled_data[table_name][num_rows_key]) < total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) < total_parent_rows: for i in num_rows_column: # If the number of rows is already at the maximum, skip # The exception is when the smallest value is already at the maximum, @@ -164,7 +168,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): break sampled_data[table_name].loc[i, num_rows_key] += 1 - if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows: break else: @@ -179,7 +183,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): break sampled_data[table_name].loc[i, num_rows_key] -= 1 - if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows: break def _sample_children(self, table_name, sampled_data, scale=1.0): @@ -221,6 +225,17 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): num_rows=1, ) + total_num_rows = round(self._table_sizes[child_name] * scale) + num_null_rows = total_num_rows - sampled_data[child_name].shape[0] + if num_null_rows > 0: + self._add_child_rows( + child_name=child_name, + parent_name=table_name, + parent_row=None, + sampled_data=sampled_data, + num_rows=num_null_rows + ) + self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale) def _finalize(self, sampled_data): diff --git 
a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index ede2ddc0c..2c92fbc6b 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1322,10 +1322,10 @@ def test_null_foreign_keys(self): metadata.add_column('parent_table', 'id', sdtype='id') metadata.set_primary_key('parent_table', 'id') - metadata.add_table('child_table1') - metadata.add_column('child_table1', 'id', sdtype='id') - metadata.set_primary_key('child_table1', 'id') - metadata.add_column('child_table1', 'fk', sdtype='id') + # metadata.add_table('child_table1') + # metadata.add_column('child_table1', 'id', sdtype='id') + # metadata.set_primary_key('child_table1', 'id') + # metadata.add_column('child_table1', 'fk', sdtype='id') metadata.add_table('child_table2') metadata.add_column('child_table2', 'id', sdtype='id') @@ -1334,12 +1334,12 @@ def test_null_foreign_keys(self): metadata.add_column('child_table2', 'fk2', sdtype='id') metadata.add_column('child_table2', 'cat_type', sdtype='categorical') - metadata.add_relationship( - parent_table_name='parent_table', - child_table_name='child_table1', - parent_primary_key='id', - child_foreign_key='fk', - ) + # metadata.add_relationship( + # parent_table_name='parent_table', + # child_table_name='child_table1', + # parent_primary_key='id', + # child_foreign_key='fk', + # ) metadata.add_relationship( parent_table_name='parent_table', @@ -1357,7 +1357,7 @@ def test_null_foreign_keys(self): data = { 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), - 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), + # 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], @@ -1374,7 +1374,7 @@ def test_null_foreign_keys(self): # Run and Assert synthesizer.fit(data) - breakpoint() + sampled_data = synthesizer.sample() def test_sampling_with_unknown_sdtype_numerical_column(self): From 
c0aa4ddf3e504625370ad865983a4eb09476708c Mon Sep 17 00:00:00 2001 From: rwedge Date: Mon, 8 Jul 2024 19:15:31 -0400 Subject: [PATCH 03/17] handle no columns to learn from case --- sdv/multi_table/hma.py | 12 ++++++------ tests/integration/multi_table/test_hma.py | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 646924f1e..06eb08506 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -314,7 +314,7 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc child_rows = child_table.loc[[foreign_key_value]] child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)] try: - if child_rows.empty: + if child_rows.empty and not pd.isna(foreign_key_value): row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: @@ -325,10 +325,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc self._set_extended_columns_distributions( synthesizer, child_name, child_rows.columns ) - synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) - row = synthesizer._get_parameters() - row = pd.Series(row) - row.index = f'__{child_name}__{foreign_key}__' + row.index + if not child_rows.empty: + synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) + row = synthesizer._get_parameters() + row = pd.Series(row) + row.index = f'__{child_name}__{foreign_key}__' + row.index if scale_columns is None: scale_columns = [column for column in row.index if column.endswith('scale')] @@ -336,7 +337,6 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if len(child_rows) == 1: row.loc[scale_columns] = None - # TODO: handle null synthesizer when child_rows is empty if pd.isna(foreign_key_value): self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer else: diff --git 
a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 2c92fbc6b..991488319 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1322,10 +1322,10 @@ def test_null_foreign_keys(self): metadata.add_column('parent_table', 'id', sdtype='id') metadata.set_primary_key('parent_table', 'id') - # metadata.add_table('child_table1') - # metadata.add_column('child_table1', 'id', sdtype='id') - # metadata.set_primary_key('child_table1', 'id') - # metadata.add_column('child_table1', 'fk', sdtype='id') + metadata.add_table('child_table1') + metadata.add_column('child_table1', 'id', sdtype='id') + metadata.set_primary_key('child_table1', 'id') + metadata.add_column('child_table1', 'fk', sdtype='id') metadata.add_table('child_table2') metadata.add_column('child_table2', 'id', sdtype='id') @@ -1334,12 +1334,12 @@ def test_null_foreign_keys(self): metadata.add_column('child_table2', 'fk2', sdtype='id') metadata.add_column('child_table2', 'cat_type', sdtype='categorical') - # metadata.add_relationship( - # parent_table_name='parent_table', - # child_table_name='child_table1', - # parent_primary_key='id', - # child_foreign_key='fk', - # ) + metadata.add_relationship( + parent_table_name='parent_table', + child_table_name='child_table1', + parent_primary_key='id', + child_foreign_key='fk', + ) metadata.add_relationship( parent_table_name='parent_table', @@ -1357,7 +1357,7 @@ def test_null_foreign_keys(self): data = { 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), - # 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), + 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], From 8bc2590f5c1a3b468e90b16af601637d20b7ae5f Mon Sep 17 00:00:00 2001 From: rwedge Date: Tue, 9 Jul 2024 12:54:49 -0400 Subject: [PATCH 04/17] adjust num null parent calculation --- sdv/sampling/hierarchical_sampler.py 
| 6 ++++-- tests/unit/sampling/test_hierarchical_sampler.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 3bff6742f..c264b33be 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -211,8 +211,9 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): sampled_data=sampled_data, ) + foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0] + if child_name not in sampled_data: # No child rows sampled, force row creation - foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0] num_rows_key = f'__{child_name}__{foreign_key}__num_rows' max_num_child_index = sampled_data[table_name][num_rows_key].idxmax() parent_row = sampled_data[table_name].iloc[max_num_child_index] @@ -226,7 +227,8 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): ) total_num_rows = round(self._table_sizes[child_name] * scale) - num_null_rows = total_num_rows - sampled_data[child_name].shape[0] + null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] + num_null_rows = round(total_num_rows * null_fk_pctg) if num_null_rows > 0: self._add_child_rows( child_name=child_name, diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index 006b14ffd..a614241de 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -177,7 +177,7 @@ def sample_children(table_name, sampled_data, scale): 'session_id': ['a', 'a', 'b'], }) - def _add_child_rows(child_name, parent_name, parent_row, sampled_data): + def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=None): if parent_name == 'users': if parent_row['user_id'] == 1: sampled_data[child_name] = pd.DataFrame({ From 9c47c91f3e6b68d244004b85c19fbcc8b93df947 Mon Sep 17 00:00:00 2001 From: rwedge 
Date: Thu, 11 Jul 2024 12:32:38 -0400 Subject: [PATCH 05/17] fix test; add ignore_cols to clear_nans --- sdv/multi_table/hma.py | 12 +++++++++--- tests/unit/multi_table/test_hma.py | 4 ++++ tests/unit/sampling/test_hierarchical_sampler.py | 8 ++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 06eb08506..04299cbe6 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -349,8 +349,12 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc return pd.DataFrame(extension_rows, index=index) @staticmethod - def _clear_nans(table_data): - for column in table_data.columns: + def _clear_nans(table_data, ignore_cols=None): + # TODO: test child with foreign key that points to multiple parents + columns = set(table_data.columns) + if ignore_cols is not None: + columns = columns - set(ignore_cols) + for column in columns: column_data = table_data[column] if column_data.dtype in (int, float): fill_value = 0 if column_data.isna().all() else column_data.mean() @@ -418,7 +422,9 @@ def _augment_table(self, table, tables, table_name): tables[table_name] = table self._learned_relationships += 1 self._augmented_tables.append(table_name) - # self._clear_nans(table) TODO: replace with standardizing nans? 
+ + foreign_keys = self.metadata._get_all_foreign_keys(table_name) + self._clear_nans(table, ignore_cols=foreign_keys) return table diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 51d5aa1a4..f62c7fd08 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -511,6 +511,7 @@ def test__get_likelihoods(self): instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() instance._synthesizer.return_value._get_likelihood.return_value = likelihoods + instance._null_child_synthesizers = {} # Run result = HMASynthesizer._get_likelihoods( @@ -550,6 +551,7 @@ def test__get_likelihoods_attribute_error(self): instance._table_synthesizers = {'child_table': child_synthesizer} instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() + instance._null_child_synthesizers = {} instance._synthesizer.return_value._get_likelihood.side_effect = [ likelihoods, AttributeError(), @@ -594,6 +596,7 @@ def test__get_likelihoods_linalg_error(self): instance._table_synthesizers = {'child_table': child_synthesizer} instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() + instance._null_child_synthesizers = {} instance._synthesizer.return_value._get_likelihood.side_effect = [ likelihoods, np.linalg.LinAlgError(), @@ -639,6 +642,7 @@ def test_get_likelihoods_filters_over_existing_columns(self, mock_concat): instance._table_synthesizers = {'child_table': child_synthesizer} instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() + instance._null_child_synthesizers = {} likelihoods = np.array([0.1, 0.2, 0.3, 0.4]) instance._synthesizer.return_value._get_likelihood.return_value = likelihoods diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index a614241de..de07a2634 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ 
b/tests/unit/sampling/test_hierarchical_sampler.py @@ -202,10 +202,13 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows= instance = Mock() instance.metadata._get_child_map.return_value = {'users': ['sessions', 'transactions']} instance.metadata._get_parent_map.return_value = {'users': []} + instance.metadata._get_foreign_keys.return_value = ['user_id'] instance._table_sizes = {'users': 10, 'sessions': 5, 'transactions': 3} instance._table_synthesizers = {'users': Mock()} instance._sample_children = sample_children instance._add_child_rows.side_effect = _add_child_rows + instance._null_child_synthesizers = {} + instance._null_foreign_key_percentages = {'__sessions__user_id': 0} # Run result = {'users': pd.DataFrame({'user_id': [1, 3]})} @@ -271,6 +274,7 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows= instance._table_synthesizers = {'users': Mock()} instance._sample_children = sample_children instance._add_child_rows.side_effect = _add_child_rows + instance._null_foreign_key_percentages = {'__sessions__user_id': 0} # Run result = {'users': pd.DataFrame({'user_id': [1], '__sessions__user_id__num_rows': [1]})} @@ -561,6 +565,7 @@ def test___enforce_table_size_too_many_rows(self): instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) @@ -580,6 +585,7 @@ def test___enforce_table_size_not_enough_rows(self): instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) @@ -599,6 +605,7 @@ def 
test___enforce_table_size_clipping(self): instance._min_child_rows = {'__child__fk__num_rows': 2} instance._max_child_rows = {'__child__fk__num_rows': 4} instance._table_sizes = {'child': 8} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) @@ -618,6 +625,7 @@ def test___enforce_table_size_too_small_sample(self): instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 0.001, data) From 482c3af50aaaf8159f5dcac82a10d5cf0224fe8e Mon Sep 17 00:00:00 2001 From: rwedge Date: Fri, 12 Jul 2024 12:49:04 -0400 Subject: [PATCH 06/17] specify dtype in test series --- tests/integration/multi_table/test_hma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 991488319..92211f62f 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1362,7 +1362,7 @@ def test_null_foreign_keys(self): 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], 'fk2': [1, 2, np.nan], - 'cat_type': ['siamese','persian', 'american shorthair'], + 'cat_type': pd.Series(['siamese','persian', 'american shorthair'], dtype='object'), }), } From 23da253de35ab1693c3ea937b3a9e14818331ba3 Mon Sep 17 00:00:00 2001 From: rwedge Date: Fri, 12 Jul 2024 15:01:48 -0400 Subject: [PATCH 07/17] lint --- sdv/multi_table/hma.py | 4 +++- sdv/multi_table/utils.py | 7 ++++--- sdv/sampling/hierarchical_sampler.py | 6 ++++-- tests/integration/multi_table/test_hma.py | 7 ++++--- tests/integration/utils/test_poc.py | 2 +- tests/unit/multi_table/test_utils.py | 2 +- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sdv/multi_table/hma.py 
b/sdv/multi_table/hma.py index 04299cbe6..9c0ea21c1 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -414,7 +414,9 @@ def _augment_table(self, table, tables, table_name): table[num_rows_key] = table[num_rows_key].fillna(0) self._max_child_rows[num_rows_key] = table[num_rows_key].max() self._min_child_rows[num_rows_key] = table[num_rows_key].min() - self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - (table[num_rows_key].sum() / child_table.shape[0]) + self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - ( + table[num_rows_key].sum() / child_table.shape[0] + ) if len(extension.columns) > 0: self._parent_extended_columns[table_name].extend(list(extension.columns)) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 60fd71053..069d9d629 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -461,9 +461,10 @@ def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep, drop_mis def _subsample_table_and_descendants(data, metadata, table, num_rows, drop_missing_values): """Subsample the table and its descendants. - The logic is to first subsample all the NaN foreign keys of the table when ``drop_missing_values`` - is True. We raise an error if we cannot reach referential integrity while keeping - the number of rows. Then, we drop rows of the descendants to ensure referential integrity. + The logic is to first subsample all the NaN foreign keys of the table when + ``drop_missing_values`` is True. We raise an error if we cannot reach referential integrity + while keeping the number of rows. Then, we drop rows of the descendants to ensure referential + integrity. 
Args: data (dict): diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index c264b33be..a5295f858 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -105,7 +105,9 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num row_indices = sampled_rows.index sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key] else: - sampled_rows[foreign_key] = parent_row[parent_key] if parent_row is not None else np.nan + sampled_rows[foreign_key] = ( + parent_row[parent_key] if parent_row is not None else np.nan + ) previous = sampled_data.get(child_name) if previous is None: @@ -235,7 +237,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): parent_name=table_name, parent_row=None, sampled_data=sampled_data, - num_rows=num_null_rows + num_rows=num_null_rows, ) self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 92211f62f..d43a649f4 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1362,7 +1362,7 @@ def test_null_foreign_keys(self): 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], 'fk2': [1, 2, np.nan], - 'cat_type': pd.Series(['siamese','persian', 'american shorthair'], dtype='object'), + 'cat_type': pd.Series(['siamese', 'persian', 'american shorthair'], dtype='object'), }), } @@ -1372,10 +1372,11 @@ def test_null_foreign_keys(self): metadata.validate() metadata.validate_data(data) - # Run and Assert + # Run synthesizer.fit(data) - sampled_data = synthesizer.sample() + synthesizer.sample() + # TODO: check results match expected def test_sampling_with_unknown_sdtype_numerical_column(self): """Test that if a numerical column is detected as unknown in the metadata, diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 
0a3e02135..3c917ccba 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -242,4 +242,4 @@ def test_get_random_subset_with_missing_values(metadata, data): # Assert assert len(result['child']) == 3 - assert result['child']['parent_id'].isnull().sum() > 0 + assert result['child']['parent_id'].ina().sum() > 0 diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index cbea576b4..e01195f89 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -1990,7 +1990,7 @@ def test__subsample_data_with_null_foreing_keys(): # Assert assert len(result_with_nan['child']) == 4 - assert result_with_nan['child']['parent_id'].isnull().sum() > 0 + assert result_with_nan['child']['parent_id'].isna().sum() > 0 assert len(result_without_nan['child']) == 2 assert set(result_without_nan['child'].index) == {0, 1} From 349cdc35819136eaffb205a271c1cdbf1beadab7 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 16 Jul 2024 10:34:00 -0400 Subject: [PATCH 08/17] fix isna in test --- tests/integration/utils/test_poc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 3c917ccba..b3dfcffdc 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -242,4 +242,4 @@ def test_get_random_subset_with_missing_values(metadata, data): # Assert assert len(result['child']) == 3 - assert result['child']['parent_id'].ina().sum() > 0 + assert result['child']['parent_id'].isna().sum() > 0 From 89c3d3fb158990af87e05b9d17e7dc677247f750 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Wed, 17 Jul 2024 21:07:14 -0400 Subject: [PATCH 09/17] finish test + remove todo --- sdv/multi_table/hma.py | 1 - tests/integration/multi_table/test_hma.py | 10 +++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py 
index 9c0ea21c1..8d0d338e4 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -350,7 +350,6 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc @staticmethod def _clear_nans(table_data, ignore_cols=None): - # TODO: test child with foreign key that points to multiple parents columns = set(table_data.columns) if ignore_cols is not None: columns = columns - set(ignore_cols) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index d43a649f4..cdb65f191 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1361,7 +1361,7 @@ def test_null_foreign_keys(self): 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], - 'fk2': [1, 2, np.nan], + 'fk2': [1, np.nan, np.nan], 'cat_type': pd.Series(['siamese', 'persian', 'american shorthair'], dtype='object'), }), } @@ -1374,9 +1374,13 @@ def test_null_foreign_keys(self): # Run synthesizer.fit(data) - synthesizer.sample() + sampled = synthesizer.sample() - # TODO: check results match expected + # Assert + assert len(sampled['parent_table']) == 3 + assert sum(pd.isna(sampled['child_table1']['fk'])) == 1 + assert sum(pd.isna(sampled['child_table2']['fk1'])) == 1 + assert sum(pd.isna(sampled['child_table2']['fk2'])) == 2 def test_sampling_with_unknown_sdtype_numerical_column(self): """Test that if a numerical column is detected as unknown in the metadata, From 5437d6844de2a40fb3d6d0db636ee945ef2f1d01 Mon Sep 17 00:00:00 2001 From: rwedge Date: Wed, 24 Jul 2024 20:45:39 -0400 Subject: [PATCH 10/17] use getattr for backwards compatibility --- sdv/sampling/hierarchical_sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index a5295f858..0d5bf855d 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -147,7 +147,8 @@ def 
_enforce_table_size(self, child_name, table_name, scale, sampled_data): """ total_num_rows = round(self._table_sizes[child_name] * scale) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): - null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] + null_fk_pctgs = getattr(self, '_null_foreign_key_percentages', {}) + null_fk_pctg = null_fk_pctgs.get(f'__{child_name}__{foreign_key}', 0) total_parent_rows = round(total_num_rows * (1 - null_fk_pctg)) num_rows_key = f'__{child_name}__{foreign_key}__num_rows' min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key] @@ -229,7 +230,8 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): ) total_num_rows = round(self._table_sizes[child_name] * scale) - null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] + null_fk_pctgs = getattr(self, '_null_foreign_key_percentages', {}) + null_fk_pctg = null_fk_pctgs.get(f'__{child_name}__{foreign_key}', 0) num_null_rows = round(total_num_rows * null_fk_pctg) if num_null_rows > 0: self._add_child_rows( From c2d6b55c883b8480560f0ee41b9e297ba238898a Mon Sep 17 00:00:00 2001 From: rwedge Date: Fri, 9 Aug 2024 18:10:33 -0400 Subject: [PATCH 11/17] a few NaN bugfixes --- sdv/multi_table/hma.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 8d0d338e4..f47c87b54 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -331,11 +331,12 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row = pd.Series(row) row.index = f'__{child_name}__{foreign_key}__' + row.index - if scale_columns is None: - scale_columns = [column for column in row.index if column.endswith('scale')] + if not pd.isna(foreign_key_value): + if scale_columns is None: + scale_columns = [column for column in row.index if column.endswith('scale')] - if len(child_rows) == 1: - 
row.loc[scale_columns] = None + if len(child_rows) == 1: + row.loc[scale_columns] = None if pd.isna(foreign_key_value): self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer @@ -598,6 +599,9 @@ def _find_parent_id(likelihoods, num_rows): candidates.append(parent) candidate_weights.append(weight) + # cast candidates to series to ensure np.random.choice uses desired dtype + candidates = pd.Series(candidates, dtype=likelihoods.index.dtype) + # All available candidates were assigned 0 likelihood of being the parent id if sum(candidate_weights) == 0: chosen_parent = np.random.choice(candidates) From 9df452df9f2ef3f94043d3e6283f141d0905ace3 Mon Sep 17 00:00:00 2001 From: rwedge Date: Mon, 12 Aug 2024 10:40:45 -0400 Subject: [PATCH 12/17] lint --- sdv/multi_table/hma.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index f47c87b54..b2e9d6566 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -333,7 +333,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if not pd.isna(foreign_key_value): if scale_columns is None: - scale_columns = [column for column in row.index if column.endswith('scale')] + scale_columns = [ + column for column in row.index if column.endswith('scale') + ] if len(child_rows) == 1: row.loc[scale_columns] = None From a397b3d5d9d9e653d21c2b3f63075997a3a1c025 Mon Sep 17 00:00:00 2001 From: rwedge Date: Mon, 12 Aug 2024 15:53:45 -0400 Subject: [PATCH 13/17] update null integration test to handle more edge cases --- tests/integration/multi_table/test_hma.py | 43 ++++++++++++++++------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index cdb65f191..082a88a54 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1318,14 +1318,19 @@ def test_null_foreign_keys(self): 
"""Test that the synthesizer does not crash when there are null foreign keys.""" # Setup metadata = MultiTableMetadata() - metadata.add_table('parent_table') - metadata.add_column('parent_table', 'id', sdtype='id') - metadata.set_primary_key('parent_table', 'id') + metadata.add_table('parent_table1') + metadata.add_column('parent_table1', 'id', sdtype='id') + metadata.set_primary_key('parent_table1', 'id') + + metadata.add_table('parent_table2') + metadata.add_column('parent_table2', 'id', sdtype='id') + metadata.set_primary_key('parent_table2', 'id') metadata.add_table('child_table1') metadata.add_column('child_table1', 'id', sdtype='id') metadata.set_primary_key('child_table1', 'id') - metadata.add_column('child_table1', 'fk', sdtype='id') + metadata.add_column('child_table1', 'fk1', sdtype='id') + metadata.add_column('child_table1', 'fk2', sdtype='id') metadata.add_table('child_table2') metadata.add_column('child_table2', 'id', sdtype='id') @@ -1335,29 +1340,41 @@ def test_null_foreign_keys(self): metadata.add_column('child_table2', 'cat_type', sdtype='categorical') metadata.add_relationship( - parent_table_name='parent_table', + parent_table_name='parent_table1', child_table_name='child_table1', parent_primary_key='id', - child_foreign_key='fk', + child_foreign_key='fk1', ) metadata.add_relationship( - parent_table_name='parent_table', + parent_table_name='parent_table2', + child_table_name='child_table1', + parent_primary_key='id', + child_foreign_key='fk2', + ) + + metadata.add_relationship( + parent_table_name='parent_table1', child_table_name='child_table2', parent_primary_key='id', child_foreign_key='fk1', ) metadata.add_relationship( - parent_table_name='parent_table', + parent_table_name='parent_table1', child_table_name='child_table2', parent_primary_key='id', child_foreign_key='fk2', ) data = { - 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), - 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), + 'parent_table1': 
pd.DataFrame({'id': [1, 2, 3]}), + 'parent_table2': pd.DataFrame({'id': ['alpha', 'beta', 'gamma']}), + 'child_table1': pd.DataFrame({ + 'id': [1, 2, 3], + 'fk1': [np.nan, 2, np.nan], + 'fk2': ['alpha', 'beta', np.nan], + }), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], @@ -1377,8 +1394,10 @@ def test_null_foreign_keys(self): sampled = synthesizer.sample() # Assert - assert len(sampled['parent_table']) == 3 - assert sum(pd.isna(sampled['child_table1']['fk'])) == 1 + assert len(sampled['parent_table1']) == 3 + assert len(sampled['parent_table2']) == 3 + assert sum(pd.isna(sampled['child_table1']['fk1'])) == 2 + assert sum(pd.isna(sampled['child_table1']['fk2'])) == 1 assert sum(pd.isna(sampled['child_table2']['fk1'])) == 1 assert sum(pd.isna(sampled['child_table2']['fk2'])) == 2 From 669d1b248fc934623b6f5f2d2c85184e31248978 Mon Sep 17 00:00:00 2001 From: rwedge Date: Mon, 12 Aug 2024 17:48:21 -0400 Subject: [PATCH 14/17] handle older pandas edge case --- sdv/multi_table/hma.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index b2e9d6566..607ce5907 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -311,7 +311,14 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc pbar_args = self._get_pbar_args(desc=progress_bar_desc) for foreign_key_value in tqdm(foreign_key_values, **pbar_args): - child_rows = child_table.loc[[foreign_key_value]] + try: + child_rows = child_table.loc[[foreign_key_value]] + except KeyError: + # pre pandas 2.1 df.loc[[np.nan]] causes error + if pd.isna(foreign_key_value): + child_rows = child_table[child_table.index.isna()] + else: + raise child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)] try: if child_rows.empty and not pd.isna(foreign_key_value): From 6bb7239ee697924bd96a61ea22cbfa4e4e02eb00 Mon Sep 17 00:00:00 2001 From: rwedge Date: Tue, 13 Aug 2024 11:07:33 -0400 
Subject: [PATCH 15/17] make get_likelihoods backwards compatible --- sdv/multi_table/hma.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 607ce5907..5de322ac4 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -660,7 +660,8 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): except (AttributeError, np.linalg.LinAlgError): likelihoods[parent_id] = None - if f'__{table_name}__{foreign_key}' in self._null_child_synthesizers: + null_child_synths = getattr(self, '_null_child_synthesizers', {}) + if f'__{table_name}__{foreign_key}' in null_child_synths: try: likelihoods[np.nan] = synthesizer._get_likelihood(table_rows) From 897b3eb5c4e8a1b1dad6961987e3483a59c56081 Mon Sep 17 00:00:00 2001 From: Roy Wedge Date: Wed, 14 Aug 2024 14:31:39 -0400 Subject: [PATCH 16/17] Set dtypes for some test case Series Co-authored-by: Gaurav Sheni --- tests/integration/multi_table/test_hma.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 082a88a54..65e04b6fb 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1372,13 +1372,13 @@ def test_null_foreign_keys(self): 'parent_table2': pd.DataFrame({'id': ['alpha', 'beta', 'gamma']}), 'child_table1': pd.DataFrame({ 'id': [1, 2, 3], - 'fk1': [np.nan, 2, np.nan], - 'fk2': ['alpha', 'beta', np.nan], + 'fk1': pd.Series([np.nan, 2, np.nan], dtype="float64"), + 'fk2': pd.Series(['alpha', 'beta', np.nan], dtype="object"), }), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], - 'fk2': [1, np.nan, np.nan], + 'fk2': pd.Series([1, np.nan, np.nan], dtype="float64"), 'cat_type': pd.Series(['siamese', 'persian', 'american shorthair'], dtype='object'), }), } From 3ef2b1d5587c0c84b0d2c1c8693f068b94283f96 Mon Sep 17 00:00:00 2001 From: Roy Wedge Date: 
Wed, 14 Aug 2024 15:34:05 -0400 Subject: [PATCH 17/17] single quotes preferred --- tests/integration/multi_table/test_hma.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 65e04b6fb..69c6a1aab 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1372,13 +1372,13 @@ def test_null_foreign_keys(self): 'parent_table2': pd.DataFrame({'id': ['alpha', 'beta', 'gamma']}), 'child_table1': pd.DataFrame({ 'id': [1, 2, 3], - 'fk1': pd.Series([np.nan, 2, np.nan], dtype="float64"), - 'fk2': pd.Series(['alpha', 'beta', np.nan], dtype="object"), + 'fk1': pd.Series([np.nan, 2, np.nan], dtype='float64'), + 'fk2': pd.Series(['alpha', 'beta', np.nan], dtype='object'), }), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], - 'fk2': pd.Series([1, np.nan, np.nan], dtype="float64"), + 'fk2': pd.Series([1, np.nan, np.nan], dtype='float64'), 'cat_type': pd.Series(['siamese', 'persian', 'american shorthair'], dtype='object'), }), }