Skip to content

Commit

Permalink
test analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
quintenroets committed Apr 13, 2024
1 parent e13ef15 commit aaf024c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Revnets
![Python version](https://img.shields.io/badge/python-3.10+-brightgreen)
![Operating system](https://img.shields.io/badge/os-linux%20%7c%20macOS-brightgreen)
![Coverage](https://img.shields.io/badge/coverage-84%25-brightgreen)
![Coverage](https://img.shields.io/badge/coverage-87%25-brightgreen)

Reverse engineer internal parameters of black box neural networks

Expand Down
23 changes: 12 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from revnets.context import context as context_
from revnets.context.context import Context
from revnets.models import HyperParameters, Path
from revnets.models import Config, HyperParameters, Path


@pytest.fixture(scope="session")
Expand All @@ -24,14 +24,15 @@ def mocked_assets_path() -> Iterator[None]:

@pytest.fixture
def test_context(context: Context, mocked_assets_path: None) -> Iterator[None]:
config = context.config
hyperparameters = HyperParameters(epochs=1, learning_rate=1.0e-2, batch_size=32)
target_network_training = context.config.target_network_training
reconstruction_training = context.config.reconstruction_training
epochs = context.config.max_difficult_inputs_epochs
context.config.target_network_training = hyperparameters
context.config.reconstruction_training = hyperparameters
context.config.max_difficult_inputs_epochs = 1
yield
context.config.target_network_training = target_network_training
context.config.reconstruction_training = reconstruction_training
context.config.max_difficult_inputs_epochs = epochs
context.loaders.config.value = Config(
target_network_training=hyperparameters,
reconstruction_training=hyperparameters,
max_difficult_inputs_epochs=1,
run_analysis=True,
)
mock = patch("matplotlib.pyplot.show")
with mock:
yield
context.loaders.config.value = config
2 changes: 1 addition & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
@pytest.mark.skipif(
not gpu_available, reason="Only test model training if GPU is available"
)
def test_main() -> None:
def test_main(test_context: None) -> None:
Experiment().run()

0 comments on commit aaf024c

Please sign in to comment.