diff --git a/README.md b/README.md index 0cd6d40..a7beafa 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Revnets ![Python version](https://img.shields.io/badge/python-3.10+-brightgreen) ![Operating system](https://img.shields.io/badge/os-linux%20%7c%20macOS-brightgreen) -![Coverage](https://img.shields.io/badge/coverage-84%25-brightgreen) +![Coverage](https://img.shields.io/badge/coverage-87%25-brightgreen) Reverse engineer internal parameters of black box neural networks diff --git a/tests/conftest.py b/tests/conftest.py index ee35341..fa51b90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import pytest from revnets.context import context as context_ from revnets.context.context import Context -from revnets.models import HyperParameters, Path +from revnets.models import Config, HyperParameters, Path @pytest.fixture(scope="session") @@ -24,14 +24,15 @@ def mocked_assets_path() -> Iterator[None]: @pytest.fixture def test_context(context: Context, mocked_assets_path: None) -> Iterator[None]: + config = context.config hyperparameters = HyperParameters(epochs=1, learning_rate=1.0e-2, batch_size=32) - target_network_training = context.config.target_network_training - reconstruction_training = context.config.reconstruction_training - epochs = context.config.max_difficult_inputs_epochs - context.config.target_network_training = hyperparameters - context.config.reconstruction_training = hyperparameters - context.config.max_difficult_inputs_epochs = 1 - yield - context.config.target_network_training = target_network_training - context.config.reconstruction_training = reconstruction_training - context.config.max_difficult_inputs_epochs = epochs + context.loaders.config.value = Config( + target_network_training=hyperparameters, + reconstruction_training=hyperparameters, + max_difficult_inputs_epochs=1, + run_analysis=True, + ) + mock = patch("matplotlib.pyplot.show") + with mock: + yield + context.loaders.config.value = config diff --git a/tests/test_main.py b/tests/test_main.py index 9f12e45..105c1e5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,5 +8,5 @@ @pytest.mark.skipif( not gpu_available, reason="Only test model training if GPU is available" ) -def test_main() -> None: +def test_main(test_context: None) -> None: Experiment().run()