From 5285fce02eb7aeeb65938d58df3dd5b53e2ee5d1 Mon Sep 17 00:00:00 2001 From: Quinten Date: Sat, 13 Apr 2024 16:57:13 -0700 Subject: [PATCH] fix tests --- src/revnets/context/context.py | 7 +++++-- src/revnets/evaluations/evaluate.py | 4 +--- src/revnets/main/main.py | 9 +++++---- src/revnets/reconstructions/queries/data.py | 4 +++- src/revnets/training/targets/network.py | 6 +----- tests/test_cli_entry_points.py | 2 +- tests/test_pipelines.py | 16 ++++++++++++++++ 7 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/revnets/context/context.py b/src/revnets/context/context.py index 28fa284..9a38381 100644 --- a/src/revnets/context/context.py +++ b/src/revnets/context/context.py @@ -1,4 +1,5 @@ from functools import cached_property +from typing import cast import torch from package_utils.context import Context as Context_ @@ -22,8 +23,10 @@ def output_path(self) -> Path: @property def results_path(self) -> Path: - path = self.output_path / "results.yaml" - return path.with_nonexistent_name() + relative_path = self.output_path.relative_to(Path.config) + path = Path.results / relative_path / "results.yaml" + path = path.with_nonexistent_name() + return cast(Path, path) @property def log_path(self) -> Path: diff --git a/src/revnets/evaluations/evaluate.py b/src/revnets/evaluations/evaluate.py index de6394a..6ef437a 100644 --- a/src/revnets/evaluations/evaluate.py +++ b/src/revnets/evaluations/evaluate.py @@ -6,7 +6,7 @@ from revnets.context import context from ..pipelines import Pipeline -from . import outputs, weights +from . import analysis, outputs, weights from .base import Evaluator from .evaluation import Evaluation @@ -17,8 +17,6 @@ def apply(evaluation_module: ModuleType) -> Any: return evaluator.get_evaluation() if context.config.evaluation.run_analysis: - from . import analysis - analysis_modules = (analysis.weights,) for analysis_module in analysis_modules: apply(analysis_module) diff --git a/src/revnets/main/main.py b/src/revnets/main/main.py index f5a0880..69eef32 100644 --- a/src/revnets/main/main.py +++ b/src/revnets/main/main.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from types import ModuleType +from typing import Any import cli import numpy as np @@ -28,15 +29,15 @@ def config(self) -> Config: def run(self) -> None: set_seed() cli.console.rule(self.config.title) - self.run_experiment() + results = self.run_experiment() + context.results_path.yaml = results - def run_experiment(self) -> None: + def run_experiment(self) -> dict[str, Any]: pipeline: Pipeline = extract_module(pipelines, self.config.pipeline).Pipeline() reconstruction = self.create_reconstruction(pipeline) evaluation = evaluations.evaluate(reconstruction, pipeline) evaluation.show() - results = {"metrics": evaluation.dict(), "config": context.config.dict()} - context.results_path.yaml = results + return {"metrics": evaluation.dict(), "config": context.config.dict()} def create_reconstruction(self, pipeline: Pipeline) -> nn.Module: module = extract_module(reconstructions, self.config.reconstruction_technique) diff --git a/src/revnets/reconstructions/queries/data.py b/src/revnets/reconstructions/queries/data.py index 7a6d14c..8a84278 100644 --- a/src/revnets/reconstructions/queries/data.py +++ b/src/revnets/reconstructions/queries/data.py @@ -16,7 +16,9 @@ def __init__( ) -> None: self.target = target if evaluation_batch_size is None: - evaluation_batch_size = context.config.evaluation_batch_size + evaluation_batch_size = ( + context.config.evaluation_batch_size + ) # pragma: nocover self.evaluation_batch_size = evaluation_batch_size tensors = torch.Tensor([]), torch.Tensor([]) super().__init__(*tensors) diff --git a/src/revnets/training/targets/network.py b/src/revnets/training/targets/network.py index f2f28f1..1b0c1f1 100644 --- a/src/revnets/training/targets/network.py +++ b/src/revnets/training/targets/network.py @@ -9,11 +9,7 @@ class Network(network.Network[Metrics]): - def __init__( - self, - model: nn.Module, - do_log: bool = True, - ) -> None: + def __init__(self, model: nn.Module, do_log: bool = True) -> None: learning_rate = context.config.target_network_training.learning_rate super().__init__(model, learning_rate, do_log) diff --git a/tests/test_cli_entry_points.py b/tests/test_cli_entry_points.py index 7abd3a4..3433b4d 100644 --- a/tests/test_cli_entry_points.py +++ b/tests/test_cli_entry_points.py @@ -5,7 +5,7 @@ @no_cli_args -@patch("revnets.main.main.Experiment.run_experiment") +@patch("revnets.main.main.Experiment.run_experiment", return_value={}) def test_main(run: MagicMock) -> None: entry_point.entry_point() run.assert_called_once() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index bcc0f50..1d6c0a1 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -11,6 +11,16 @@ pipelines.images.mediumnet_small, ) +all_pipeline_modules = ( + *pipeline_modules, + pipelines.mininet.mininet_40, + pipelines.mininet.mininet_100, + pipelines.images.mininet_100, + pipelines.images.mininet_128, + pipelines.images.mininet_200, + pipelines.mediumnet.mediumnet_40, +) + @pytest.mark.parametrize("pipeline_module", pipeline_modules) def test_target_network_training( @@ -18,3 +28,9 @@ def test_target_network_training( ) -> None: pipeline: Pipeline = pipeline_module.Pipeline() pipeline.create_target_network() + + +@pytest.mark.parametrize("pipeline_module", all_pipeline_modules) +def test_network_factory(pipeline_module: ModuleType, test_context: None) -> None: + pipeline: Pipeline = pipeline_module.Pipeline() + pipeline.create_network_factory()