Skip to content

Commit

Permalink
add experiments launching (#11)
Browse files Browse the repository at this point in the history
* add experiments launching
  • Loading branch information
quintenroets committed Apr 14, 2024
1 parent 7f021b9 commit a48aaa8
Show file tree
Hide file tree
Showing 51 changed files with 190 additions and 138 deletions.
7 changes: 5 additions & 2 deletions src/revnets/context/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import cached_property
from typing import cast

import torch
from package_utils.context import Context as Context_
Expand All @@ -22,8 +23,10 @@ def output_path(self) -> Path:

@property
def results_path(self) -> Path:
path = self.output_path / "results.yaml"
return path.with_nonexistent_name()
relative_path = self.output_path.relative_to(Path.config)
path = Path.results / relative_path / "results.yaml"
path = path.with_nonexistent_name()
return cast(Path, path)

@property
def log_path(self) -> Path:
Expand Down
8 changes: 6 additions & 2 deletions src/revnets/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@

@dataclass
class DataModule(LightningDataModule):
batch_size: int = context.config.target_network_training.batch_size
batch_size: int = field(
default_factory=lambda: context.config.target_network_training.batch_size
)
evaluation_batch_size: int = 1000
validation_ratio = context.config.validation_ratio
validation_ratio: float = field(
default_factory=lambda: context.config.validation_ratio
)
train: data.Dataset[Any] = field(init=False)
validation: data.Dataset[Any] = field(init=False)
train_validation: data.Dataset[Any] = field(init=False)
Expand Down
1 change: 0 additions & 1 deletion src/revnets/evaluations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import analysis, attack, outputs, weights
from .base import Evaluator
from .evaluate import evaluate
from .evaluation import Evaluation
4 changes: 2 additions & 2 deletions src/revnets/evaluations/analysis/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch.nn import Module
from torch.utils.data import TensorDataset

from revnets.reconstructions.queries.data import QueryDataSet
from revnets.reconstructions.queries.random import Reconstructor
from revnets.utils.data import compute_targets

from .. import base

Expand Down Expand Up @@ -48,7 +48,7 @@ def visualize_network(self, network: nn.Module, name: str) -> None:

def visualize_model_outputs(self, model: Module, name: str) -> None:
inputs = Reconstructor(self.pipeline).create_queries(self.n_inputs)
outputs = QueryDataSet(model).compute_targets(inputs)
outputs = compute_targets(inputs, model)
if self.activation:
outputs = F.relu(outputs) # pragma: nocover
ActivationsVisualizer(outputs, name).run()
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/evaluations/analysis/trained_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Evaluator(base.Evaluator):
def evaluate(self) -> None:
dataset = self.pipeline.load_prepared_data()
network = Network(self.pipeline.target, learning_rate=0)
network = Network(self.pipeline.target)
dataloaders = (
dataset.train_dataloader(),
dataset.val_dataloader(),
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/evaluations/analysis/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from torch.nn import Module

from revnets.context import context
from revnets.standardization import extract_layer_weights, generate_layers

from ...utils.colors import get_colors
from ..weights import layers_mae
from ..weights.standardize import extract_layer_weights, generate_layers

cpu = torch.device("cpu")

Expand Down
3 changes: 1 addition & 2 deletions src/revnets/evaluations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ def apply(evaluation_module: ModuleType) -> Any:
evaluator: Evaluator = evaluation_module.Evaluator(reconstruction, pipeline)
return evaluator.get_evaluation()

analysis_modules = (analysis.weights,)

if context.config.evaluation.run_analysis:
analysis_modules = (analysis.weights,)
for analysis_module in analysis_modules:
apply(analysis_module)

Expand Down
3 changes: 2 additions & 1 deletion src/revnets/evaluations/weights/layers_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import torch

from revnets.standardization import extract_layer_weights, generate_layers

from . import mae
from .standardize import extract_layer_weights, generate_layers


class Evaluator(mae.Evaluator):
Expand Down
3 changes: 2 additions & 1 deletion src/revnets/evaluations/weights/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import torch

from revnets.standardization import Standardizer, align

from ...context import context
from .. import base
from .standardize import Standardizer, align


@dataclass
Expand Down
6 changes: 2 additions & 4 deletions src/revnets/launching/lauch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@
class LaunchPlan:
@classmethod
def reconstruction_techniques(cls) -> Iterator[ModuleType]:
yield reconstructions.cheat
yield reconstructions.queries.random

@classmethod
def pipelines(cls) -> Iterator[ModuleType]:
# yield pipelines.mininet_images.mininet_100
yield pipelines.mininet.mininet

@classmethod
def seeds(cls) -> Iterator[int]:
yield 77
# yield from (77, 78, 79)
yield from range(5)

@classmethod
def experiments_to_launch(cls) -> Iterator[models.Experiment]:
Expand Down
13 changes: 10 additions & 3 deletions src/revnets/launching/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def main() -> None:
def launch(experiment: Experiment) -> None:
config_dict = cast(dict[str, Any], context.options.config_path.yaml)
experiment.prepare_config(config_dict)
print(experiment.config_path)
cli.run("ls")
# cli.run("revnets", "--config-path", experiment.config_path)
command = "revnets", "--config-path", str(experiment.config_path)
title = "_".join(experiment.names)
launch_command(command, title)


def launch_command(command: tuple[str, ...], title: str) -> None:
command_str = " ".join(command)
shell_command = f"{command_str}; fish"
launch_command = ("tmux", "new-session", "-s", title, "-d", shell_command)
cli.run(launch_command)
9 changes: 5 additions & 4 deletions src/revnets/main/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from types import ModuleType
from typing import Any

import cli
import numpy as np
Expand Down Expand Up @@ -28,15 +29,15 @@ def config(self) -> Config:
def run(self) -> None:
set_seed()
cli.console.rule(self.config.title)
self.run_experiment()
results = self.run_experiment()
context.results_path.yaml = results

def run_experiment(self) -> None:
def run_experiment(self) -> dict[str, Any]:
pipeline: Pipeline = extract_module(pipelines, self.config.pipeline).Pipeline()
reconstruction = self.create_reconstruction(pipeline)
evaluation = evaluations.evaluate(reconstruction, pipeline)
evaluation.show()
results = {"metrics": evaluation.dict(), "config": context.config.dict()}
context.results_path.yaml = results
return {"metrics": evaluation.dict(), "config": context.config.dict()}

def create_reconstruction(self, pipeline: Pipeline) -> nn.Module:
module = extract_module(reconstructions, self.config.reconstruction_technique)
Expand Down
14 changes: 7 additions & 7 deletions src/revnets/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ class Evaluation:
class Config(SerializationMixin):
sampling_data_size: int = 20000
reconstruction_training: HyperParameters = HyperParameters(
epochs=300, learning_rate=3e-2, batch_size=256
epochs=300, learning_rate=1e-1, batch_size=256
)
reconstruct_from_checkpoint: bool = True
always_train: bool = False
reconstruct_from_checkpoint: bool = False
always_train: bool = True
n_rounds: int = 2
experiment: Experiment = field(default_factory=Experiment)

Expand Down Expand Up @@ -70,14 +70,14 @@ class Config(SerializationMixin):
console_metrics_refresh_interval: float = 0.5
max_difficult_inputs_epochs: int = 100

limit_batches: int | None = None

@property
def number_of_validation_sanity_steps(self) -> int | None:
return 0 if self.debug else 0

@property
def limit_batches(self) -> int | None:
return self.debug_batch_limit if self.debug else None

def __post_init__(self) -> None:
if self.evaluation.run_analysis:
self.always_train = False
if self.debug:
self.limit_batches = self.debug_batch_limit # pragma: nocover
2 changes: 1 addition & 1 deletion src/revnets/models/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generate_title_parts(self) -> Iterator[str]:

@property
def path(self) -> Path:
path = Path.results
path = Path.config
for name in self.names:
path /= name
return cast(Path, path)
Expand Down
8 changes: 5 additions & 3 deletions src/revnets/networks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable
from dataclasses import dataclass
from dataclasses import dataclass, field

import torch
from torch import nn
Expand All @@ -11,10 +11,12 @@
from ..utils import NamedClass


@dataclass(frozen=True)
@dataclass
class NetworkFactory(NamedClass):
output_size: int = 10
activation: Activation = context.config.target_network_training.activation
activation: Activation = field(
default_factory=lambda: context.config.target_network_training.activation
)

def create_activation_layer(self) -> Module:
activation_layer: Module
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/images/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .. import mediumnet


@dataclass(frozen=True)
@dataclass
class NetworkFactory(mediumnet.NetworkFactory):
input_size: int = 784
hidden_size1: int = 512
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/images/mininet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .. import mininet


@dataclass(frozen=True)
@dataclass
class NetworkFactory(mininet.NetworkFactory):
input_size: int = 784
hidden_size: int = 40
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import base


@dataclass(frozen=True)
@dataclass
class NetworkFactory(base.NetworkFactory):
input_size: int = 40
hidden_size1: int = 20
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/mininet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import base


@dataclass(frozen=True)
@dataclass
class NetworkFactory(base.NetworkFactory):
input_size: int = 40
hidden_size: int = 20
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/pipelines/images/mediumnet/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@


class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.images.mediumnet.NetworkFactory()
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mediumnet.NetworkFactory()
8 changes: 5 additions & 3 deletions src/revnets/pipelines/images/mediumnet/mediumnet_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

@dataclass
class Pipeline(mediumnet.Pipeline):
network_factory: NetworkFactory = networks.images.mediumnet.NetworkFactory(
hidden_size1=100, hidden_size2=50
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mediumnet.NetworkFactory(
hidden_size1=100, hidden_size2=50
)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet_small.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=100
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=100)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_128.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet_small.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=128
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=128)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_200.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet_small.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=200
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=200)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

@dataclass
class Pipeline(train.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=40
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=40)

@classmethod
def load_data(cls) -> data.mnist.DataModule:
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mediumnet/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@


class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mediumnet.NetworkFactory()
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mediumnet.NetworkFactory()
6 changes: 3 additions & 3 deletions src/revnets/pipelines/mediumnet/mediumnet_40.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mediumnet.NetworkFactory(
hidden_size1=40, hidden_size2=20
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mediumnet.NetworkFactory(hidden_size1=40, hidden_size2=20)
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mininet/mininet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

@dataclass
class Pipeline(train.Pipeline):
network_factory: NetworkFactory = mininet.NetworkFactory()
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return mininet.NetworkFactory()

@classmethod
def load_data(cls) -> mnist1d.DataModule:
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mininet/mininet_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@

@dataclass
class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mininet.NetworkFactory(hidden_size=100)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mininet.NetworkFactory(hidden_size=100)
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mininet/mininet_40.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@

@dataclass
class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mininet.NetworkFactory(hidden_size=40)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mininet.NetworkFactory(hidden_size=40)
Loading

0 comments on commit a48aaa8

Please sign in to comment.