Skip to content

Commit

Permalink
add CNN (#12)
Browse files Browse the repository at this point in the history
* generalize standardization

* test generalized standardization
  • Loading branch information
quintenroets committed Apr 15, 2024
1 parent a48aaa8 commit 68a28fe
Show file tree
Hide file tree
Showing 42 changed files with 385 additions and 182 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"simple-classproperty >=4.0.2, <5",
"superpathlib >=2.0.2, <3",
"torch >=2.2.2, <3",
"torchsummary >=1.5.1, <2",
"torchvision >=0.17.2, <1",
"torchtoolbox >=0.1.8.2, <1",
"torchtext >=0.17.2, <1",
Expand Down Expand Up @@ -62,6 +63,7 @@ module = [
"kneed.*",
"scipy.*",
"sklearn.*",
"torchsummary.*",
"torchvision.*",
]
ignore_missing_imports = true
Expand Down
7 changes: 6 additions & 1 deletion src/revnets/data/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, TypeVar
from typing import Any, TypeVar, cast

import torch
from pytorch_lightning import LightningDataModule
Expand Down Expand Up @@ -49,3 +49,8 @@ def test_dataloader(self) -> DataLoader[Any]:
return DataLoader(
self.test, batch_size=self.evaluation_batch_size, shuffle=False
)

@property
def input_shape(self) -> tuple[int, ...]:
inputs, target = self.train_validation[0]
return cast(tuple[int, ...], inputs.shape)
11 changes: 8 additions & 3 deletions src/revnets/evaluations/analysis/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch.nn import Module
from torch.utils.data import TensorDataset

from revnets.reconstructions.queries.random import Reconstructor
from revnets.utils.data import compute_targets

from .. import base
Expand All @@ -31,7 +30,7 @@ def evaluate(self) -> None:
self.visualize_network(model, name)

def visualize_random_inputs(self) -> None:
inputs = Reconstructor(self.pipeline).create_queries(self.n_inputs)
inputs = self.create_queries()
ActivationsVisualizer(inputs, "random inputs").run()

def visualize_train_inputs(self) -> None:
Expand All @@ -47,12 +46,18 @@ def visualize_network(self, network: nn.Module, name: str) -> None:
self.visualize_model_outputs(model, name)

def visualize_model_outputs(self, model: Module, name: str) -> None:
inputs = Reconstructor(self.pipeline).create_queries(self.n_inputs)
inputs = self.create_queries()
outputs = compute_targets(inputs, model)
if self.activation:
outputs = F.relu(outputs) # pragma: nocover
ActivationsVisualizer(outputs, name).run()

def create_queries(self) -> torch.Tensor:
# Circular import: reconstructions should import evaluations
from revnets.reconstructions.queries.random import Reconstructor

return Reconstructor(self.pipeline).create_queries(self.n_inputs)


@dataclass
class ActivationsVisualizer:
Expand Down
22 changes: 2 additions & 20 deletions src/revnets/evaluations/analysis/weights.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

Expand All @@ -7,7 +6,6 @@
from torch.nn import Module

from revnets.context import context
from revnets.standardization import extract_layer_weights, generate_layers

from ...utils.colors import get_colors
from ..weights import layers_mae
Expand All @@ -27,23 +25,16 @@ def evaluate(self) -> None:
self.visualize_network_differences()

def visualize_network_weights(self, network: Module, name: str) -> None:
layer_weights = self.generate_layer_weights(network)
layer_weights = layers_mae.generate_layer_weights(network)
for i, weights in enumerate(layer_weights):
title = f"{name} layer {i + 1} weights".capitalize()
self.visualize_layer_weights(weights, title)

def generate_layer_weights(self, network: Module) -> Iterator[torch.Tensor]:
layers = generate_layers(network)
for layer in layers:
weights = self.extract_layer_weights(layer)
if weights is not None:
yield weights

@classmethod
def visualize_layer_weights(
cls, weights: torch.Tensor, title: str, n_show: int | None = None
) -> None:
weights = weights[:n_show]
weights = weights[:n_show].cpu()

# weights = torch.transpose(weights, 0, 1)

Expand Down Expand Up @@ -84,12 +75,3 @@ def create_figure(cls) -> Any:
def show(cls) -> None:
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

def iterate_compared_layers(
self, device: torch.device | None = cpu
) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
return super().iterate_compared_layers(device=device)

@classmethod
def extract_layer_weights(cls, layer: Module) -> torch.Tensor | None:
return extract_layer_weights(layer, device=cpu)
43 changes: 21 additions & 22 deletions src/revnets/evaluations/weights/layers_mae.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
from collections.abc import Iterator

import torch
from torch.nn import Module
from torch.nn.functional import l1_loss

from revnets.standardization import extract_layer_weights, generate_layers
from revnets.standardization import extract_weights, generate_layers

from . import mae


def generate_layer_weights(model: Module) -> Iterator[torch.Tensor]:
for layer in generate_layers(model):
try:
yield extract_weights(layer)
except StopIteration:
pass


class Evaluator(mae.Evaluator):
def iterate_compared_layers(
self, device: torch.device | None = None
) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
original_layers = generate_layers(self.target)
reconstruction_layers = generate_layers(self.reconstruction)
for original, reconstruction in zip(original_layers, reconstruction_layers):
original_weights = extract_layer_weights(original, device)
reconstruction_weights = extract_layer_weights(reconstruction, device)
if original_weights is not None and reconstruction_weights is not None:
yield original_weights, reconstruction_weights

def calculate_distance(self) -> tuple[float, ...]:
return tuple(
self.calculate_weights_distance(original, reconstructed)
for original, reconstructed in self.iterate_compared_layers()
)
def iterate_compared_layers(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
networks = self.target, self.reconstruction
weights_pair = [generate_layer_weights(self.target) for network in networks]
yield from zip(*weights_pair)

def calculate_total_distance(self) -> tuple[float, ...]:
pairs = self.iterate_compared_layers()
return tuple(self.calculate_distance(*pair) for pair in pairs)

@classmethod
def calculate_weights_distance(
cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor
) -> float:
distance = torch.nn.functional.l1_loss(original_weights, reconstructed_weights)
return distance.item()
def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float:
return l1_loss(values, other).item()

@classmethod
def format_evaluation(cls, value: tuple[float, ...], precision: int = 3) -> str: # type: ignore[override]
Expand Down
10 changes: 3 additions & 7 deletions src/revnets/evaluations/weights/mae.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import torch
from torch.nn.functional import l1_loss

from . import mse


class Evaluator(mse.Evaluator):
@classmethod
def calculate_weights_distance(
cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor
) -> float:
distance = torch.nn.functional.l1_loss(
original_weights, reconstructed_weights, reduction="sum"
)
return distance.item()
def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float:
return l1_loss(values, other, reduction="sum").item()
15 changes: 5 additions & 10 deletions src/revnets/evaluations/weights/max_ae.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import torch
from torch.nn.functional import l1_loss

from . import mae


class Evaluator(mae.Evaluator):
def calculate_distance(self) -> float:
def calculate_total_distance(self) -> float:
return max(
self.calculate_weights_distance(original, reconstruction)
self.calculate_distance(original, reconstruction)
for original, reconstruction in self.iterate_compared_layers()
)

@classmethod
def calculate_weights_distance(
cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor
) -> float:
distances = torch.nn.functional.l1_loss(
original_weights, reconstructed_weights, reduction="none"
)
distance = distances.max()
return distance.item()
def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float:
return l1_loss(values, other, reduction="none").max().item()
16 changes: 6 additions & 10 deletions src/revnets/evaluations/weights/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import cast

import torch
from torch.nn.functional import mse_loss

from revnets.standardization import Standardizer, align

Expand All @@ -13,7 +14,7 @@
@dataclass
class Evaluator(base.Evaluator):
def evaluate(self) -> float | tuple[float, ...] | None:
return self.calculate_distance() if self.standardize_networks() else None
return self.calculate_total_distance() if self.standardize_networks() else None

def standardize_networks(self) -> bool:
standardized = self.has_same_architecture()
Expand All @@ -31,24 +32,19 @@ def has_same_architecture(self) -> bool:
for original, reconstruction in self.iterate_compared_layers()
)

def calculate_distance(self) -> float | tuple[float, ...]:
def calculate_total_distance(self) -> float | tuple[float, ...]:
layer_weights = self.target.state_dict().values()
total_size = sum(weights.numel() for weights in layer_weights)
total_distance = sum(
self.calculate_weights_distance(original, reconstruction)
self.calculate_distance(original, reconstruction)
for original, reconstruction in self.iterate_compared_layers()
)
distance = total_distance / total_size
return cast(float, distance)

@classmethod
def calculate_weights_distance(
cls, original: torch.Tensor, reconstruction: torch.Tensor
) -> float:
distance = torch.nn.functional.mse_loss(
original, reconstruction, reduction="sum"
)
return distance.item()
def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float:
return mse_loss(values, other, reduction="sum").item()

def iterate_compared_layers(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
yield from zip(
Expand Down
4 changes: 2 additions & 2 deletions src/revnets/evaluations/weights/named_layers_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@


class Evaluator(layers_mae.Evaluator):
def calculate_distance(self) -> dict[str, float]: # type: ignore[override]
def calculate_total_distance(self) -> dict[str, float]: # type: ignore[override]
return {
name: self.calculate_weights_distance(original, reconstructed)
name: self.calculate_distance(original, reconstructed)
for name, original, reconstructed in self.iterate_named_compared_layers()
}

Expand Down
8 changes: 8 additions & 0 deletions src/revnets/main/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import torch
from torch import nn
from torchsummary import summary

from revnets import pipelines
from revnets.models import Experiment as Config
Expand Down Expand Up @@ -34,11 +35,18 @@ def run(self) -> None:

def run_experiment(self) -> dict[str, Any]:
pipeline: Pipeline = extract_module(pipelines, self.config.pipeline).Pipeline()
self.log_number_of_parameters(pipeline)
reconstruction = self.create_reconstruction(pipeline)
evaluation = evaluations.evaluate(reconstruction, pipeline)
evaluation.show()
return {"metrics": evaluation.dict(), "config": context.config.dict()}

def log_number_of_parameters(self, pipeline: Pipeline) -> None:
network = pipeline.create_initialized_network()
network = network.to(dtype=torch.float32).to(context.device)
data = pipeline.load_prepared_data()
summary(network, data.input_shape)

def create_reconstruction(self, pipeline: Pipeline) -> nn.Module:
module = extract_module(reconstructions, self.config.reconstruction_technique)
reconstructor: Reconstructor = module.Reconstructor(pipeline)
Expand Down
10 changes: 6 additions & 4 deletions src/revnets/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class Evaluation:

@dataclass
class Config(SerializationMixin):
sampling_data_size: int = 20000
sampling_data_size: int = 102400
reconstruction_training: HyperParameters = HyperParameters(
epochs=300, learning_rate=1e-1, batch_size=256
epochs=300, learning_rate=1e-2, batch_size=256
)
reconstruct_from_checkpoint: bool = False
always_train: bool = True
Expand All @@ -47,11 +47,14 @@ class Config(SerializationMixin):
target_network_training: HyperParameters = HyperParameters(
epochs=100, learning_rate=1.0e-2, batch_size=32
)
difficult_inputs_training: HyperParameters = HyperParameters(
epochs=1000, learning_rate=1.0e-3
)
evaluation: Evaluation = field(default_factory=Evaluation)
evaluation_batch_size: int = 1000

num_workers: int = 8
early_stopping_patience: int = 20
early_stopping_patience: int = 10
n_networks: int = 2
visualization_interval = 10
weight_variance_downscale_factor: float | None = None
Expand All @@ -68,7 +71,6 @@ class Config(SerializationMixin):
validation_ratio: float = 0.1

console_metrics_refresh_interval: float = 0.5
max_difficult_inputs_epochs: int = 100

limit_batches: int | None = None

Expand Down
1 change: 1 addition & 0 deletions src/revnets/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class NetworkFactory(NamedClass):
activation: Activation = field(
default_factory=lambda: context.config.target_network_training.activation
)
input_shape: tuple[int, ...] | None = None

def create_activation_layer(self) -> Module:
activation_layer: Module
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/networks/images/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import mediumnet, mininet
from . import cnn, mediumnet, mininet
1 change: 1 addition & 0 deletions src/revnets/networks/images/cnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import lenet, mini
34 changes: 34 additions & 0 deletions src/revnets/networks/images/cnn/lenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from collections.abc import Iterable
from dataclasses import dataclass

from torch import nn

from . import mini


@dataclass
class NetworkFactory(mini.NetworkFactory):
hidden_size1: int = 120
hidden_size2: int = 84

def create_layers(self) -> Iterable[nn.Module]:
yield from (
# 28 x 28 x 1
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
nn.Tanh(),
# 24 x 24 x 6
nn.AvgPool2d(kernel_size=2, stride=2),
# 12 x 12 x 6
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
nn.Tanh(),
# 8 x 8 x 16
nn.AvgPool2d(kernel_size=2, stride=2),
# 4 x 4 x 16
nn.Flatten(),
# 256
nn.Linear(in_features=256, out_features=self.hidden_size1),
nn.Tanh(),
nn.Linear(in_features=self.hidden_size1, out_features=self.hidden_size2),
nn.Tanh(),
nn.Linear(in_features=self.hidden_size2, out_features=self.output_size),
)
Loading

0 comments on commit 68a28fe

Please sign in to comment.