Skip to content

Commit

Permalink
increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
quintenroets committed Apr 13, 2024
1 parent aaf024c commit bccd449
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Revnets
![Python version](https://img.shields.io/badge/python-3.10+-brightgreen)
![Operating system](https://img.shields.io/badge/os-linux%20%7c%20macOS-brightgreen)
![Coverage](https://img.shields.io/badge/coverage-87%25-brightgreen)
![Coverage](https://img.shields.io/badge/coverage-92%25-brightgreen)

Reverse engineer internal parameters of black box neural networks

Expand Down
2 changes: 1 addition & 1 deletion src/revnets/evaluations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import attack, outputs, weights
from . import analysis, attack, outputs, weights
from .evaluate import evaluate
from .evaluation import Evaluation
2 changes: 1 addition & 1 deletion src/revnets/evaluations/attack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def evaluate(self) -> Evaluation:
return self.compare_attacks()

def compare_attacks(self) -> Evaluation:
data = self.load_data()
data = self.pipeline.load_prepared_data()
model = AttackNetwork(self.original, self.reconstruction)
dataloader = data.test_dataloader()
precision = 32 # adversarial attack library only works with precision 32
Expand Down
1 change: 0 additions & 1 deletion src/revnets/evaluations/attack/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def configure_attack(self, inputs: torch.Tensor) -> None:
self.model_under_attack = PyTorchClassifier(
model=self.reconstruction,
loss=torch.nn.CrossEntropyLoss(),
optimizer=self.reconstruction.configure_optimizers(),
input_shape=inputs.shape[1:],
nb_classes=outputs.shape[-1],
device_type="gpu",
Expand Down
6 changes: 0 additions & 6 deletions src/revnets/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from revnets.context import context
from revnets.pipelines import Pipeline

from ..data import DataModule


@dataclass
class Evaluator:
Expand Down Expand Up @@ -40,7 +38,3 @@ def format_evaluation(

def evaluate(self) -> Any:
raise NotImplementedError

def load_data(self) -> DataModule:
assert self.pipeline is not None
return self.pipeline.load_prepared_data()
2 changes: 1 addition & 1 deletion src/revnets/evaluations/outputs/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def evaluate(self) -> Metrics:

def compare_outputs(self) -> Metrics:
model = CompareModel(self.original, self.reconstruction)
data = self.load_data()
data = self.pipeline.load_prepared_data()
dataloader = self.extract_dataloader(data)
Trainer().test(model, dataloaders=dataloader) # noqa
return cast(Metrics, model.metrics)
Expand Down
35 changes: 31 additions & 4 deletions tests/test_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,31 @@

import pytest
from revnets import evaluations, pipelines, reconstructions
from revnets.evaluations import analysis, attack, outputs, weights
from revnets.pipelines import Pipeline

pipeline_modules = (pipelines.mininet,)
evaluation_modules = (
weights.mse,
weights.mae,
weights.max_ae,
weights.layers_mae,
weights.named_layers_mae,
attack.attack,
outputs.train,
outputs.val,
outputs.test,
analysis.weights,
analysis.activations,
analysis.trained_target,
)


@pytest.mark.parametrize("pipeline_module", pipeline_modules)
def test_cheat_evaluations(pipeline_module: ModuleType) -> None:
pipeline = pipeline_module.Pipeline()
@pytest.fixture
def pipeline() -> Pipeline:
return pipelines.mininet.Pipeline()


def test_cheat_evaluations(pipeline: Pipeline) -> None:
reconstructor = reconstructions.cheat.Reconstructor(pipeline)
reconstruction = reconstructor.create_reconstruction()

Expand All @@ -24,3 +42,12 @@ def test_cheat_evaluations(pipeline_module: ModuleType) -> None:
for value in perfect_metrics:
if value is not None and value != "/":
assert math.isclose(float(value), 0, abs_tol=1e-5)


@pytest.mark.parametrize("evaluation_module", evaluation_modules)
def test_evaluations(
evaluation_module: ModuleType, pipeline: Pipeline, test_context: None
) -> None:
reconstructor = reconstructions.empty.Reconstructor(pipeline)
reconstruction = reconstructor.create_reconstruction()
evaluation_module.Evaluator(reconstruction, pipeline).evaluate()

0 comments on commit bccd449

Please sign in to comment.