Skip to content

Commit

Permalink
comment
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Jun 20, 2023
1 parent 6661b9a commit 98b7a83
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
10 changes: 5 additions & 5 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,16 +358,15 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
"""Find parent ids for the given table and foreign key.
The parent ids are chosen randomly based on the likelihood of the available
parent ids in the parent table. If the parent table is not sampled, this method
will first sample rows for the parent table.
parent ids in the parent table.
Args:
child_table (pd.DataFrame):
The child table dataframe.
parent_table (pd.DataFrame):
The parent table dataframe.
child_name (str):
The name of the child table.
parent_name (dict):
Map of table name to sampled data (pandas.DataFrame).
foreign_key (str):
Expand All @@ -377,6 +376,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
pandas.Series:
The parent ids for the given table data.
"""
# Create a copy of the parent table with the primary key as index to calculate likilihoods
primary_key = self.metadata.tables[parent_name].primary_key
parent_table = parent_table.set_index(primary_key)
num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0).clip(0)
Expand Down
9 changes: 5 additions & 4 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,10 +604,11 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
* int:
Number of rows that are considered valid.
"""
if (self._model and (
self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns)):
if not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)
if self._model and not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)

need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns
if self._model and need_sample:

if conditions is None:
raw_sampled = self._sample(num_rows)
Expand Down

0 comments on commit 98b7a83

Please sign in to comment.