From aaf024c6470e92df8884a81227e2d7a5b2a6adf9 Mon Sep 17 00:00:00 2001 From: Quinten Date: Fri, 12 Apr 2024 23:17:08 -0700 Subject: [PATCH] test analysis --- README.md | 2 +- tests/conftest.py | 23 ++++++++++++----------- tests/test_main.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0cd6d40..a7beafa 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Revnets ![Python version](https://img.shields.io/badge/python-3.10+-brightgreen) ![Operating system](https://img.shields.io/badge/os-linux%20%7c%20macOS-brightgreen) -![Coverage](https://img.shields.io/badge/coverage-84%25-brightgreen) +![Coverage](https://img.shields.io/badge/coverage-87%25-brightgreen) Reverse engineer internal parameters of black box neural networks diff --git a/tests/conftest.py b/tests/conftest.py index ee35341..fa51b90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import pytest from revnets.context import context as context_ from revnets.context.context import Context -from revnets.models import HyperParameters, Path +from revnets.models import Config, HyperParameters, Path @pytest.fixture(scope="session") @@ -24,14 +24,15 @@ def mocked_assets_path() -> Iterator[None]: @pytest.fixture def test_context(context: Context, mocked_assets_path: None) -> Iterator[None]: + config = context.config hyperparameters = HyperParameters(epochs=1, learning_rate=1.0e-2, batch_size=32) - target_network_training = context.config.target_network_training - reconstruction_training = context.config.reconstruction_training - epochs = context.config.max_difficult_inputs_epochs - context.config.target_network_training = hyperparameters - context.config.reconstruction_training = hyperparameters - context.config.max_difficult_inputs_epochs = 1 - yield - context.config.target_network_training = target_network_training - context.config.reconstruction_training = reconstruction_training - context.config.max_difficult_inputs_epochs = epochs + context.loaders.config.value = Config( + target_network_training=hyperparameters, + reconstruction_training=hyperparameters, + max_difficult_inputs_epochs=1, + run_analysis=True, + ) + mock = patch("matplotlib.pyplot.show") + with mock: + yield + context.loaders.config.value = config diff --git a/tests/test_main.py b/tests/test_main.py index 9f12e45..105c1e5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,5 +8,5 @@ @pytest.mark.skipif( not gpu_available, reason="Only test model training if GPU is available" ) -def test_main() -> None: +def test_main(test_context: None) -> None: Experiment().run()