Skip to content

Commit

Permalink
Move progress bar out of base multi table synthesizer (#1495)
Browse files Browse the repository at this point in the history
* move progress_bar HMA

* docstring

* set verbose to false in the base

* test

* remove parameter base
  • Loading branch information
R-Palazzo committed Jul 10, 2023
1 parent 750332c commit 804b52e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def _print(self, text='', **kwargs):
if self.verbose:
print(text, **kwargs) # noqa: T001

def __init__(self, metadata, locales=None, synthesizer_kwargs=None, verbose=True):
def __init__(self, metadata, locales=None, synthesizer_kwargs=None):
self.metadata = metadata
self.metadata.validate()
self.locales = locales
self.verbose = verbose
self.verbose = False
self._table_synthesizers = {}
self._table_parameters = defaultdict(dict)
if synthesizer_kwargs is not None:
Expand Down
1 change: 0 additions & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self, metadata, locales=None, verbose=True):
self._augmented_tables = []
self._learned_relationships = 0
self.verbose = verbose

BaseHierarchicalSampler.__init__(
self,
self.metadata,
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import re

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -686,3 +687,28 @@ def test_use_own_data_using_hma(tmp_path):

for table in metadata.tables:
assert set(synthetic_data[table].columns) == set(datasets[table].columns)


def test_progress_bar_print(capsys):
"""Test that the progress bar prints correctly."""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
hmasynthesizer = HMASynthesizer(metadata)

key_phrases = [
r'Preprocess Tables:',
r'Learning relationships:',
r"\(1/2\) Tables 'characters' and 'character_families' \('character_id'\):",
r"\(2/2\) Tables 'families' and 'character_families' \('family_id'\):"
]

# Run
hmasynthesizer.fit(data)
hmasynthesizer.sample(0.5)

captured = capsys.readouterr()

# Assert
for pattern in key_phrases:
match = re.search(pattern, captured.out + captured.err)
assert match is not None

0 comments on commit 804b52e

Please sign in to comment.