Skip to content

Commit

Permalink
move progress_bar HMA
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jul 7, 2023
1 parent 750332c commit 889bd7b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
10 changes: 3 additions & 7 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ class BaseMultiTableSynthesizer:
for.
locales (list or str):
The default locale(s) to use for AnonymizedFaker transformers. Defaults to ``None``.
verbose (bool):
Whether to print progress for fitting or not.
"""

DEFAULT_SYNTHESIZER_KWARGS = None
Expand Down Expand Up @@ -62,20 +60,18 @@ def _initialize_models(self):

def _get_pbar_args(self, **kwargs):
"""Return a dictionary with the updated keyword args for a progress bar."""
pbar_args = {'disable': not self.verbose}
pbar_args = {'disable': True}
pbar_args.update(kwargs)

return pbar_args

def _print(self, text='', **kwargs):
if self.verbose:
print(text, **kwargs) # noqa: T001
print(text, **kwargs) # noqa: T001

def __init__(self, metadata, locales=None, synthesizer_kwargs=None, verbose=True):
def __init__(self, metadata, locales=None, synthesizer_kwargs=None):
self.metadata = metadata
self.metadata.validate()
self.locales = locales
self.verbose = verbose
self._table_synthesizers = {}
self._table_parameters = defaultdict(dict)
if synthesizer_kwargs is not None:
Expand Down
7 changes: 7 additions & 0 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def __init__(self, metadata, locales=None, verbose=True):
self._table_synthesizers,
self._table_sizes)

def _get_pbar_args(self, **kwargs):
"""Return a dictionary with the updated keyword args for a progress bar."""
pbar_args = super()._get_pbar_args(**kwargs)
pbar_args['disable'] = not self.verbose

return pbar_args

def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
"""Generate the extension columns for this child table.
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import re

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -686,3 +687,30 @@ def test_use_own_data_using_hma(tmp_path):

for table in metadata.tables:
assert set(synthetic_data[table].columns) == set(datasets[table].columns)


def test_progress_bar_print(capsys):
"""Test that the progress bar prints correctly."""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
hmasynthesizer = HMASynthesizer(metadata)

# Define the phrases that should be printed
key_phrases = [
r'Preprocess Tables:',
r'Learning relationships:',
r"\(1/2\) Tables 'characters' and 'character_families' \('character_id'\):",
r"\(2/2\) Tables 'families' and 'character_families' \('family_id'\):"
]

# Run
hmasynthesizer.fit(data)
hmasynthesizer.sample(0.5)

# Capture output
captured = capsys.readouterr()

# Assert
for pattern in key_phrases:
match = re.search(pattern, captured.out + captured.err)
assert match is not None
6 changes: 2 additions & 4 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ def test__initialize_models(self):
])

def test__get_pbar_args(self):
"""Test that ``_get_pbar_args`` returns a dictionary with disable opposite to verbose."""
"""Test that ``_get_pbar_args`` returns a dictionary with disable True."""
# Setup
instance = Mock()
instance.verbose = False

# Run
result = BaseMultiTableSynthesizer._get_pbar_args(instance)
Expand All @@ -68,7 +67,6 @@ def test__get_pbar_args_kwargs(self):
"""Test that ``_get_pbar_args`` returns a dictionary with the given kwargs."""
# Setup
instance = Mock()
instance.verbose = True

# Run
result = BaseMultiTableSynthesizer._get_pbar_args(
Expand All @@ -79,7 +77,7 @@ def test__get_pbar_args_kwargs(self):

# Assert
assert result == {
'disable': False,
'disable': True,
'desc': 'Process Table',
'position': 0
}
Expand Down

0 comments on commit 889bd7b

Please sign in to comment.