From 804b52ef2baeec28860396e1efda73c560cea60b Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Mon, 10 Jul 2023 19:22:27 +0200 Subject: [PATCH] Move progress bar out of base multi table synthesizer (#1495) * move progress_bar HMA * docstring * set verbose to false in the base * test * remove parameter base --- sdv/multi_table/base.py | 4 ++-- sdv/multi_table/hma.py | 1 - tests/integration/multi_table/test_hma.py | 26 +++++++++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 39ea391b7..cb8034efe 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -71,11 +71,11 @@ def _print(self, text='', **kwargs): if self.verbose: print(text, **kwargs) # noqa: T001 - def __init__(self, metadata, locales=None, synthesizer_kwargs=None, verbose=True): + def __init__(self, metadata, locales=None, synthesizer_kwargs=None): self.metadata = metadata self.metadata.validate() self.locales = locales - self.verbose = verbose + self.verbose = False self._table_synthesizers = {} self._table_parameters = defaultdict(dict) if synthesizer_kwargs is not None: diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 9615257fc..84189613b 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -38,7 +38,6 @@ def __init__(self, metadata, locales=None, verbose=True): self._augmented_tables = [] self._learned_relationships = 0 self.verbose = verbose - BaseHierarchicalSampler.__init__( self, self.metadata, diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 016b3fde5..7ce5213dc 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1,4 +1,5 @@ import datetime +import re import numpy as np import pandas as pd @@ -686,3 +687,28 @@ def test_use_own_data_using_hma(tmp_path): for table in metadata.tables: assert set(synthetic_data[table].columns) == set(datasets[table].columns) + + +def test_progress_bar_print(capsys): + """Test that the progress bar prints correctly.""" + # Setup + data, metadata = download_demo('multi_table', 'got_families') + hmasynthesizer = HMASynthesizer(metadata) + + key_phrases = [ + r'Preprocess Tables:', + r'Learning relationships:', + r"\(1/2\) Tables 'characters' and 'character_families' \('character_id'\):", + r"\(2/2\) Tables 'families' and 'character_families' \('family_id'\):" + ] + + # Run + hmasynthesizer.fit(data) + hmasynthesizer.sample(0.5) + + captured = capsys.readouterr() + + # Assert + for pattern in key_phrases: + match = re.search(pattern, captured.out + captured.err) + assert match is not None