Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add experiments launching #11

Merged
merged 3 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/revnets/context/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import cached_property
from typing import cast

import torch
from package_utils.context import Context as Context_
Expand All @@ -22,8 +23,10 @@ def output_path(self) -> Path:

@property
def results_path(self) -> Path:
path = self.output_path / "results.yaml"
return path.with_nonexistent_name()
relative_path = self.output_path.relative_to(Path.config)
path = Path.results / relative_path / "results.yaml"
path = path.with_nonexistent_name()
return cast(Path, path)

@property
def log_path(self) -> Path:
Expand Down
8 changes: 6 additions & 2 deletions src/revnets/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@

@dataclass
class DataModule(LightningDataModule):
batch_size: int = context.config.target_network_training.batch_size
batch_size: int = field(
default_factory=lambda: context.config.target_network_training.batch_size
)
evaluation_batch_size: int = 1000
validation_ratio = context.config.validation_ratio
validation_ratio: float = field(
default_factory=lambda: context.config.validation_ratio
)
train: data.Dataset[Any] = field(init=False)
validation: data.Dataset[Any] = field(init=False)
train_validation: data.Dataset[Any] = field(init=False)
Expand Down
1 change: 0 additions & 1 deletion src/revnets/evaluations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import analysis, attack, outputs, weights
from .base import Evaluator
from .evaluate import evaluate
from .evaluation import Evaluation
4 changes: 2 additions & 2 deletions src/revnets/evaluations/analysis/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch.nn import Module
from torch.utils.data import TensorDataset

from revnets.reconstructions.queries.data import QueryDataSet
from revnets.reconstructions.queries.random import Reconstructor
from revnets.utils.data import compute_targets

from .. import base

Expand Down Expand Up @@ -48,7 +48,7 @@ def visualize_network(self, network: nn.Module, name: str) -> None:

def visualize_model_outputs(self, model: Module, name: str) -> None:
inputs = Reconstructor(self.pipeline).create_queries(self.n_inputs)
outputs = QueryDataSet(model).compute_targets(inputs)
outputs = compute_targets(inputs, model)
if self.activation:
outputs = F.relu(outputs) # pragma: nocover
ActivationsVisualizer(outputs, name).run()
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/evaluations/analysis/trained_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Evaluator(base.Evaluator):
def evaluate(self) -> None:
dataset = self.pipeline.load_prepared_data()
network = Network(self.pipeline.target, learning_rate=0)
network = Network(self.pipeline.target)
dataloaders = (
dataset.train_dataloader(),
dataset.val_dataloader(),
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/evaluations/analysis/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from torch.nn import Module

from revnets.context import context
from revnets.standardization import extract_layer_weights, generate_layers

from ...utils.colors import get_colors
from ..weights import layers_mae
from ..weights.standardize import extract_layer_weights, generate_layers

cpu = torch.device("cpu")

Expand Down
3 changes: 1 addition & 2 deletions src/revnets/evaluations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ def apply(evaluation_module: ModuleType) -> Any:
evaluator: Evaluator = evaluation_module.Evaluator(reconstruction, pipeline)
return evaluator.get_evaluation()

analysis_modules = (analysis.weights,)

if context.config.evaluation.run_analysis:
analysis_modules = (analysis.weights,)
for analysis_module in analysis_modules:
apply(analysis_module)

Expand Down
3 changes: 2 additions & 1 deletion src/revnets/evaluations/weights/layers_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import torch

from revnets.standardization import extract_layer_weights, generate_layers

from . import mae
from .standardize import extract_layer_weights, generate_layers


class Evaluator(mae.Evaluator):
Expand Down
3 changes: 2 additions & 1 deletion src/revnets/evaluations/weights/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import torch

from revnets.standardization import Standardizer, align

from ...context import context
from .. import base
from .standardize import Standardizer, align


@dataclass
Expand Down
6 changes: 2 additions & 4 deletions src/revnets/launching/lauch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@
class LaunchPlan:
@classmethod
def reconstruction_techniques(cls) -> Iterator[ModuleType]:
yield reconstructions.cheat
yield reconstructions.queries.random

@classmethod
def pipelines(cls) -> Iterator[ModuleType]:
# yield pipelines.mininet_images.mininet_100
yield pipelines.mininet.mininet

@classmethod
def seeds(cls) -> Iterator[int]:
yield 77
# yield from (77, 78, 79)
yield from range(5)

@classmethod
def experiments_to_launch(cls) -> Iterator[models.Experiment]:
Expand Down
13 changes: 10 additions & 3 deletions src/revnets/launching/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def main() -> None:
def launch(experiment: Experiment) -> None:
config_dict = cast(dict[str, Any], context.options.config_path.yaml)
experiment.prepare_config(config_dict)
print(experiment.config_path)
cli.run("ls")
# cli.run("revnets", "--config-path", experiment.config_path)
command = "revnets", "--config-path", str(experiment.config_path)
title = "_".join(experiment.names)
launch_command(command, title)


def launch_command(command: tuple[str, ...], title: str) -> None:
command_str = " ".join(command)
shell_command = f"{command_str}; fish"
launch_command = ("tmux", "new-session", "-s", title, "-d", shell_command)
cli.run(launch_command)
9 changes: 5 additions & 4 deletions src/revnets/main/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from types import ModuleType
from typing import Any

import cli
import numpy as np
Expand Down Expand Up @@ -28,15 +29,15 @@ def config(self) -> Config:
def run(self) -> None:
set_seed()
cli.console.rule(self.config.title)
self.run_experiment()
results = self.run_experiment()
context.results_path.yaml = results

def run_experiment(self) -> None:
def run_experiment(self) -> dict[str, Any]:
pipeline: Pipeline = extract_module(pipelines, self.config.pipeline).Pipeline()
reconstruction = self.create_reconstruction(pipeline)
evaluation = evaluations.evaluate(reconstruction, pipeline)
evaluation.show()
results = {"metrics": evaluation.dict(), "config": context.config.dict()}
context.results_path.yaml = results
return {"metrics": evaluation.dict(), "config": context.config.dict()}

def create_reconstruction(self, pipeline: Pipeline) -> nn.Module:
module = extract_module(reconstructions, self.config.reconstruction_technique)
Expand Down
14 changes: 7 additions & 7 deletions src/revnets/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ class Evaluation:
class Config(SerializationMixin):
sampling_data_size: int = 20000
reconstruction_training: HyperParameters = HyperParameters(
epochs=300, learning_rate=3e-2, batch_size=256
epochs=300, learning_rate=1e-1, batch_size=256
)
reconstruct_from_checkpoint: bool = True
always_train: bool = False
reconstruct_from_checkpoint: bool = False
always_train: bool = True
n_rounds: int = 2
experiment: Experiment = field(default_factory=Experiment)

Expand Down Expand Up @@ -70,14 +70,14 @@ class Config(SerializationMixin):
console_metrics_refresh_interval: float = 0.5
max_difficult_inputs_epochs: int = 100

limit_batches: int | None = None

@property
def number_of_validation_sanity_steps(self) -> int | None:
return 0 if self.debug else 0

@property
def limit_batches(self) -> int | None:
return self.debug_batch_limit if self.debug else None

def __post_init__(self) -> None:
if self.evaluation.run_analysis:
self.always_train = False
if self.debug:
self.limit_batches = self.debug_batch_limit # pragma: nocover
2 changes: 1 addition & 1 deletion src/revnets/models/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generate_title_parts(self) -> Iterator[str]:

@property
def path(self) -> Path:
path = Path.results
path = Path.config
for name in self.names:
path /= name
return cast(Path, path)
Expand Down
8 changes: 5 additions & 3 deletions src/revnets/networks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable
from dataclasses import dataclass
from dataclasses import dataclass, field

import torch
from torch import nn
Expand All @@ -11,10 +11,12 @@
from ..utils import NamedClass


@dataclass(frozen=True)
@dataclass
class NetworkFactory(NamedClass):
output_size: int = 10
activation: Activation = context.config.target_network_training.activation
activation: Activation = field(
default_factory=lambda: context.config.target_network_training.activation
)

def create_activation_layer(self) -> Module:
activation_layer: Module
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/images/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .. import mediumnet


@dataclass(frozen=True)
@dataclass
class NetworkFactory(mediumnet.NetworkFactory):
input_size: int = 784
hidden_size1: int = 512
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/images/mininet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .. import mininet


@dataclass(frozen=True)
@dataclass
class NetworkFactory(mininet.NetworkFactory):
input_size: int = 784
hidden_size: int = 40
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import base


@dataclass(frozen=True)
@dataclass
class NetworkFactory(base.NetworkFactory):
input_size: int = 40
hidden_size1: int = 20
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/mininet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import base


@dataclass(frozen=True)
@dataclass
class NetworkFactory(base.NetworkFactory):
input_size: int = 40
hidden_size: int = 20
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/pipelines/images/mediumnet/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@


class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.images.mediumnet.NetworkFactory()
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mediumnet.NetworkFactory()
8 changes: 5 additions & 3 deletions src/revnets/pipelines/images/mediumnet/mediumnet_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

@dataclass
class Pipeline(mediumnet.Pipeline):
network_factory: NetworkFactory = networks.images.mediumnet.NetworkFactory(
hidden_size1=100, hidden_size2=50
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mediumnet.NetworkFactory(
hidden_size1=100, hidden_size2=50
)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet_small.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=100
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=100)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_128.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet_small.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=128
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=128)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_200.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet_small.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=200
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=200)
6 changes: 3 additions & 3 deletions src/revnets/pipelines/images/mininet/mininet_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

@dataclass
class Pipeline(train.Pipeline):
network_factory: NetworkFactory = networks.images.mininet.NetworkFactory(
hidden_size=40
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.images.mininet.NetworkFactory(hidden_size=40)

@classmethod
def load_data(cls) -> data.mnist.DataModule:
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mediumnet/mediumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@


class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mediumnet.NetworkFactory()
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mediumnet.NetworkFactory()
6 changes: 3 additions & 3 deletions src/revnets/pipelines/mediumnet/mediumnet_40.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@


class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mediumnet.NetworkFactory(
hidden_size1=40, hidden_size2=20
)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mediumnet.NetworkFactory(hidden_size1=40, hidden_size2=20)
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mininet/mininet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

@dataclass
class Pipeline(train.Pipeline):
network_factory: NetworkFactory = mininet.NetworkFactory()
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return mininet.NetworkFactory()

@classmethod
def load_data(cls) -> mnist1d.DataModule:
Expand Down
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mininet/mininet_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@

@dataclass
class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mininet.NetworkFactory(hidden_size=100)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mininet.NetworkFactory(hidden_size=100)
4 changes: 3 additions & 1 deletion src/revnets/pipelines/mininet/mininet_40.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@

@dataclass
class Pipeline(mininet.Pipeline):
network_factory: NetworkFactory = networks.mininet.NetworkFactory(hidden_size=40)
@classmethod
def create_network_factory(cls) -> NetworkFactory:
return networks.mininet.NetworkFactory(hidden_size=40)
Loading