Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
quintenroets committed Apr 13, 2024
1 parent 3c7e5f4 commit 5285fce
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 16 deletions.
7 changes: 5 additions & 2 deletions src/revnets/context/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import cached_property
from typing import cast

import torch
from package_utils.context import Context as Context_
Expand All @@ -22,8 +23,10 @@ def output_path(self) -> Path:

@property
def results_path(self) -> Path:
path = self.output_path / "results.yaml"
return path.with_nonexistent_name()
relative_path = self.output_path.relative_to(Path.config)
path = Path.results / relative_path / "results.yaml"
path = path.with_nonexistent_name()
return cast(Path, path)

@property
def log_path(self) -> Path:
Expand Down
4 changes: 1 addition & 3 deletions src/revnets/evaluations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from revnets.context import context

from ..pipelines import Pipeline
from . import outputs, weights
from . import analysis, outputs, weights
from .base import Evaluator
from .evaluation import Evaluation

Expand All @@ -17,8 +17,6 @@ def apply(evaluation_module: ModuleType) -> Any:
return evaluator.get_evaluation()

if context.config.evaluation.run_analysis:
from . import analysis

analysis_modules = (analysis.weights,)
for analysis_module in analysis_modules:
apply(analysis_module)
Expand Down
9 changes: 5 additions & 4 deletions src/revnets/main/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from types import ModuleType
from typing import Any

import cli
import numpy as np
Expand Down Expand Up @@ -28,15 +29,15 @@ def config(self) -> Config:
def run(self) -> None:
set_seed()
cli.console.rule(self.config.title)
self.run_experiment()
results = self.run_experiment()
context.results_path.yaml = results

def run_experiment(self) -> None:
def run_experiment(self) -> dict[str, Any]:
pipeline: Pipeline = extract_module(pipelines, self.config.pipeline).Pipeline()
reconstruction = self.create_reconstruction(pipeline)
evaluation = evaluations.evaluate(reconstruction, pipeline)
evaluation.show()
results = {"metrics": evaluation.dict(), "config": context.config.dict()}
context.results_path.yaml = results
return {"metrics": evaluation.dict(), "config": context.config.dict()}

def create_reconstruction(self, pipeline: Pipeline) -> nn.Module:
module = extract_module(reconstructions, self.config.reconstruction_technique)
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/reconstructions/queries/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def __init__(
) -> None:
self.target = target
if evaluation_batch_size is None:
evaluation_batch_size = context.config.evaluation_batch_size
evaluation_batch_size = (
context.config.evaluation_batch_size
) # pragma: nocover
self.evaluation_batch_size = evaluation_batch_size
tensors = torch.Tensor([]), torch.Tensor([])
super().__init__(*tensors)
Expand Down
6 changes: 1 addition & 5 deletions src/revnets/training/targets/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@


class Network(network.Network[Metrics]):
def __init__(
self,
model: nn.Module,
do_log: bool = True,
) -> None:
def __init__(self, model: nn.Module, do_log: bool = True) -> None:
learning_rate = context.config.target_network_training.learning_rate
super().__init__(model, learning_rate, do_log)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@no_cli_args
@patch("revnets.main.main.Experiment.run_experiment")
@patch("revnets.main.main.Experiment.run_experiment", return_value={})
def test_main(run: MagicMock) -> None:
entry_point.entry_point()
run.assert_called_once()
Expand Down
16 changes: 16 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,26 @@
pipelines.images.mediumnet_small,
)

all_pipeline_modules = (
*pipeline_modules,
pipelines.mininet.mininet_40,
pipelines.mininet.mininet_100,
pipelines.images.mininet_100,
pipelines.images.mininet_128,
pipelines.images.mininet_200,
pipelines.mediumnet.mediumnet_40,
)


@pytest.mark.parametrize("pipeline_module", pipeline_modules)
def test_target_network_training(
pipeline_module: ModuleType, test_context: None
) -> None:
pipeline: Pipeline = pipeline_module.Pipeline()
pipeline.create_target_network()


@pytest.mark.parametrize("pipeline_module", all_pipeline_modules)
def test_network_factory(pipeline_module: ModuleType, test_context: None) -> None:
pipeline: Pipeline = pipeline_module.Pipeline()
pipeline.create_network_factory()

0 comments on commit 5285fce

Please sign in to comment.