From 98b7a83b3e30bc207fd3f3b47bf6d715676dbbb1 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 20 Jun 2023 09:22:26 -0400 Subject: [PATCH] comment --- sdv/multi_table/hma.py | 10 +++++----- sdv/single_table/base.py | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 361115c7e..9615257fc 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -358,16 +358,15 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f """Find parent ids for the given table and foreign key. The parent ids are chosen randomly based on the likelihood of the available - parent ids in the parent table. If the parent table is not sampled, this method - will first sample rows for the parent table. + parent ids in the parent table. Args: child_table (pd.DataFrame): - + The child table dataframe. parent_table (pd.DataFrame): - + The parent table dataframe. child_name (str): - + The name of the child table. parent_name (dict): Map of table name to sampled data (pandas.DataFrame). foreign_key (str): @@ -377,6 +376,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f pandas.Series: The parent ids for the given table data. """ + # Create a copy of the parent table with the primary key as index to calculate likilihoods primary_key = self.metadata.tables[parent_name].primary_key parent_table = parent_table.set_index(primary_key) num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0).clip(0) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 6a690eb28..4bc080a1e 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -604,10 +604,11 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, * int: Number of rows that are considered valid. """ - if (self._model and ( - self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns)): - if not self._random_state_set: - self._set_random_state(FIXED_RNG_SEED) + if self._model and not self._random_state_set: + self._set_random_state(FIXED_RNG_SEED) + + need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns + if self._model and need_sample: if conditions is None: raw_sampled = self._sample(num_rows)