Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor progress bars #1272

Merged
merged 2 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
# Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# TQDM and nbsphinx do not play well together. Therefore, disable TQDM
# for the documentation build.
# (`Content block expected for the "raw" directive; none found.`)
os.environ["TQDM_DISABLE"] = "1"

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
Expand Down
4 changes: 2 additions & 2 deletions pypesto/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self):

@abc.abstractmethod
def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Execute tasks.

Expand All @@ -22,6 +22,6 @@ def execute(
tasks:
List of tasks to execute.
progress_bar:
Whether to display a progress bar. Defaults to ``True``.
Whether to display a progress bar.
"""
raise NotImplementedError("This engine is not intended to be called.")
8 changes: 4 additions & 4 deletions pypesto/engine/mpi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import cloudpickle as pickle
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand All @@ -32,7 +32,7 @@ def __init__(self):
super().__init__()

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""
Pickle tasks and distribute work to workers.
Expand All @@ -42,7 +42,7 @@ def execute(
tasks:
List of :class:`pypesto.engine.Task` to execute.
progress_bar:
Whether to display a progress bar. Defaults to ``True``.
Whether to display a progress bar.

Returns
-------
Expand All @@ -55,6 +55,6 @@ def execute(

with MPIPoolExecutor() as executor:
results = executor.map(
work, tqdm(pickled_tasks, disable=not progress_bar)
work, tqdm(pickled_tasks, enable=progress_bar)
)
return results
8 changes: 4 additions & 4 deletions pypesto/engine/multi_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Any, Union

import cloudpickle as pickle
from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
self.method: str = method

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Pickle tasks and distribute work over parallel processes.

Expand All @@ -61,7 +61,7 @@ def execute(
tasks:
List of :class:`pypesto.engine.Task` to execute.
progress_bar:
Whether to display a progress bar. Defaults to ``True``.
Whether to display a progress bar.

Returns
-------
Expand All @@ -81,7 +81,7 @@ def execute(
tqdm(
pool.imap(work, pickled_tasks),
total=len(pickled_tasks),
disable=not progress_bar,
enable=progress_bar,
),
)

Expand Down
7 changes: 3 additions & 4 deletions pypesto/engine/multi_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union

from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand Down Expand Up @@ -43,7 +42,7 @@ def __init__(self, n_threads: Union[int, None] = None):
self.n_threads: int = n_threads

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Deepcopy tasks and distribute work over parallel threads.

Expand All @@ -70,7 +69,7 @@ def execute(
tqdm(
pool.map(work, copied_tasks),
total=len(copied_tasks),
disable=not progress_bar,
enable=progress_bar,
),
)

Expand Down
10 changes: 6 additions & 4 deletions pypesto/engine/single_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Engines without parallelization."""
from typing import Any

from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand All @@ -18,7 +17,7 @@ def __init__(self):
super().__init__()

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Execute all tasks in a simple for loop sequentially.

Expand All @@ -34,7 +33,10 @@ def execute(
A list of results.
"""
results = []
for task in tqdm(tasks, disable=not progress_bar):
for task in tqdm(
tasks,
enable=progress_bar,
):
results.append(task.execute())

return results
2 changes: 1 addition & 1 deletion pypesto/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def predict(
include_llh_weights: bool = False,
include_sigmay: bool = False,
engine: Engine = None,
progress_bar: bool = True,
progress_bar: bool = None,
) -> EnsemblePrediction:
"""
Run predictions for a full ensemble.
Expand Down
2 changes: 1 addition & 1 deletion pypesto/optimize/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def minimize(
startpoint_method: Union[StartpointMethod, Callable, bool] = None,
result: Result = None,
engine: Engine = None,
progress_bar: bool = True,
progress_bar: bool = None,
options: OptimizeOptions = None,
history_options: HistoryOptions = None,
filename: Union[str, Callable, None] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypesto/profile/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def parameter_profile(
result_index: int = 0,
next_guess_method: Union[Callable, str] = 'adaptive_step_regression',
profile_options: ProfileOptions = None,
progress_bar: bool = True,
progress_bar: bool = None,
filename: Union[str, Callable, None] = None,
overwrite: bool = False,
) -> Result:
Expand Down
2 changes: 1 addition & 1 deletion pypesto/sample/adaptive_metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def default_options(cls):
# target acceptance rate
'target_acceptance_rate': 0.234,
# show progress
'show_progress': True,
'show_progress': None,
}

def initialize(self, problem: Problem, x0: np.ndarray):
Expand Down
8 changes: 4 additions & 4 deletions pypesto/sample/metropolis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Dict, Sequence, Union

import numpy as np
from tqdm import tqdm

from ..history import NoHistory
from ..objective import NegLogPriors, ObjectiveBase
from ..problem import Problem
from ..result import McmcPtResult
from ..util import tqdm
from .sampler import InternalSample, InternalSampler


Expand Down Expand Up @@ -51,7 +51,7 @@ def default_options(cls):
"""Return the default options for the sampler."""
return {
'std': 1.0, # the proposal standard deviation
'show_progress': True, # whether to show the progress
'show_progress': None, # whether to show the progress
}

def initialize(self, problem: Problem, x0: np.ndarray):
Expand All @@ -73,10 +73,10 @@ def sample(self, n_samples: int, beta: float = 1.0):
lpost = -self.trace_neglogpost[-1]
lprior = -self.trace_neglogprior[-1]

show_progress = self.options['show_progress']
show_progress = self.options.get('show_progress', None)

# loop over iterations
for _ in tqdm(range(int(n_samples)), disable=not show_progress):
for _ in tqdm(range(int(n_samples)), enable=show_progress):
# perform step
x, lpost, lprior = self._perform_step(
x=x, lpost=lpost, lprior=lprior, beta=beta
Expand Down
8 changes: 4 additions & 4 deletions pypesto/sample/parallel_tempering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Dict, List, Sequence, Union

import numpy as np
from tqdm import tqdm

from ..problem import Problem
from ..result import McmcPtResult
from ..util import tqdm
from .sampler import InternalSampler, Sampler


Expand Down Expand Up @@ -70,7 +70,7 @@ def default_options(cls) -> Dict:
'max_temp': 5e4,
'exponent': 4,
'temper_log_posterior': False,
'show_progress': True,
'show_progress': None,
}

def initialize(
Expand All @@ -89,9 +89,9 @@ def initialize(

def sample(self, n_samples: int, beta: float = 1.0):
"""Sample and swap in between samplers."""
show_progress = self.options['show_progress']
show_progress = self.options.get('show_progress', None)
# loop over iterations
for i_sample in tqdm(range(int(n_samples)), disable=not show_progress):
for i_sample in tqdm(range(int(n_samples)), enable=show_progress):
# TODO test
# sample
for sampler, beta in zip(self.samplers, self.betas):
Expand Down
38 changes: 38 additions & 0 deletions pypesto/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
from scipy import cluster
from tqdm import tqdm as _tqdm


def _check_none(fun: Callable[..., Any]) -> Callable[..., Union[Any, None]]:
Expand Down Expand Up @@ -295,3 +296,40 @@ def delete_nan_inf(
)
)
return x, fvals[finite_fvals]


def tqdm(*args, enable: bool = None, **kwargs):
"""
Create a progress using tqdm.
dweindl marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
args:
Arguments passed to tqdm.
enable:
Whether to enable the progress bar.
If None, use tqdm defaults.
Mutually exclusive with `disable`.
kwargs:
Keyword arguments passed to tqdm.

Returns
-------
progress_bar:
A progress bar.
"""
# Drop the `disable` argument unless it is not-None.
# This way, we don't interfere with TQDM_DISABLE or other global
# tqdm settings.
disable = kwargs.pop("disable", None)

if enable is not None:
if disable is not None and enable != disable:
raise ValueError(
"Contradicting values for `enable` and `disable` passed."
)
disable = not enable

if disable is not None:
kwargs["disable"] = disable
return _tqdm(*args, **kwargs)