diff --git a/pyproject.toml b/pyproject.toml index f7c5b60..d3900ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "simple-classproperty >=4.0.2, <5", "superpathlib >=2.0.2, <3", "torch >=2.2.2, <3", + "torchsummary >=1.5.1, <2", "torchvision >=0.17.2, <1", "torchtoolbox >=0.1.8.2, <1", "torchtext >=0.17.2, <1", @@ -62,6 +63,7 @@ module = [ "kneed.*", "scipy.*", "sklearn.*", + "torchsummary.*", "torchvision.*", ] ignore_missing_imports = true diff --git a/src/revnets/data/base.py b/src/revnets/data/base.py index 2105816..c232b10 100644 --- a/src/revnets/data/base.py +++ b/src/revnets/data/base.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import torch from pytorch_lightning import LightningDataModule @@ -49,3 +49,8 @@ def test_dataloader(self) -> DataLoader[Any]: return DataLoader( self.test, batch_size=self.evaluation_batch_size, shuffle=False ) + + @property + def input_shape(self) -> tuple[int, ...]: + inputs, target = self.train_validation[0] + return cast(tuple[int, ...], inputs.shape) diff --git a/src/revnets/evaluations/analysis/activations.py b/src/revnets/evaluations/analysis/activations.py index 4fa17f1..db34199 100644 --- a/src/revnets/evaluations/analysis/activations.py +++ b/src/revnets/evaluations/analysis/activations.py @@ -9,7 +9,6 @@ from torch.nn import Module from torch.utils.data import TensorDataset -from revnets.reconstructions.queries.random import Reconstructor from revnets.utils.data import compute_targets from .. 
import base @@ -31,7 +30,7 @@ def evaluate(self) -> None: self.visualize_network(model, name) def visualize_random_inputs(self) -> None: - inputs = Reconstructor(self.pipeline).create_queries(self.n_inputs) + inputs = self.create_queries() ActivationsVisualizer(inputs, "random inputs").run() def visualize_train_inputs(self) -> None: @@ -47,12 +46,18 @@ def visualize_network(self, network: nn.Module, name: str) -> None: self.visualize_model_outputs(model, name) def visualize_model_outputs(self, model: Module, name: str) -> None: - inputs = Reconstructor(self.pipeline).create_queries(self.n_inputs) + inputs = self.create_queries() outputs = compute_targets(inputs, model) if self.activation: outputs = F.relu(outputs) # pragma: nocover ActivationsVisualizer(outputs, name).run() + def create_queries(self) -> torch.Tensor: + # Circular import: reconstructions should import evaluations + from revnets.reconstructions.queries.random import Reconstructor + + return Reconstructor(self.pipeline).create_queries(self.n_inputs) + @dataclass class ActivationsVisualizer: diff --git a/src/revnets/evaluations/analysis/weights.py b/src/revnets/evaluations/analysis/weights.py index 86ed4b2..31c3b9d 100644 --- a/src/revnets/evaluations/analysis/weights.py +++ b/src/revnets/evaluations/analysis/weights.py @@ -1,4 +1,3 @@ -from collections.abc import Iterator from dataclasses import dataclass from typing import Any @@ -7,7 +6,6 @@ from torch.nn import Module from revnets.context import context -from revnets.standardization import extract_layer_weights, generate_layers from ...utils.colors import get_colors from ..weights import layers_mae @@ -27,23 +25,16 @@ def evaluate(self) -> None: self.visualize_network_differences() def visualize_network_weights(self, network: Module, name: str) -> None: - layer_weights = self.generate_layer_weights(network) + layer_weights = layers_mae.generate_layer_weights(network) for i, weights in enumerate(layer_weights): title = f"{name} layer {i + 1} 
weights".capitalize() self.visualize_layer_weights(weights, title) - def generate_layer_weights(self, network: Module) -> Iterator[torch.Tensor]: - layers = generate_layers(network) - for layer in layers: - weights = self.extract_layer_weights(layer) - if weights is not None: - yield weights - @classmethod def visualize_layer_weights( cls, weights: torch.Tensor, title: str, n_show: int | None = None ) -> None: - weights = weights[:n_show] + weights = weights[:n_show].cpu() # weights = torch.transpose(weights, 0, 1) @@ -84,12 +75,3 @@ def create_figure(cls) -> Any: def show(cls) -> None: plt.legend(loc="upper left", bbox_to_anchor=(1, 1)) plt.show() - - def iterate_compared_layers( - self, device: torch.device | None = cpu - ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: - return super().iterate_compared_layers(device=device) - - @classmethod - def extract_layer_weights(cls, layer: Module) -> torch.Tensor | None: - return extract_layer_weights(layer, device=cpu) diff --git a/src/revnets/evaluations/weights/layers_mae.py b/src/revnets/evaluations/weights/layers_mae.py index 714ade5..d5470ba 100644 --- a/src/revnets/evaluations/weights/layers_mae.py +++ b/src/revnets/evaluations/weights/layers_mae.py @@ -1,36 +1,35 @@ from collections.abc import Iterator import torch +from torch.nn import Module +from torch.nn.functional import l1_loss -from revnets.standardization import extract_layer_weights, generate_layers +from revnets.standardization import extract_weights, generate_layers from . 
import mae +def generate_layer_weights(model: Module) -> Iterator[torch.Tensor]: + for layer in generate_layers(model): + try: + yield extract_weights(layer) + except StopIteration: + pass + + class Evaluator(mae.Evaluator): - def iterate_compared_layers( - self, device: torch.device | None = None - ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: - original_layers = generate_layers(self.target) - reconstruction_layers = generate_layers(self.reconstruction) - for original, reconstruction in zip(original_layers, reconstruction_layers): - original_weights = extract_layer_weights(original, device) - reconstruction_weights = extract_layer_weights(reconstruction, device) - if original_weights is not None and reconstruction_weights is not None: - yield original_weights, reconstruction_weights - - def calculate_distance(self) -> tuple[float, ...]: - return tuple( - self.calculate_weights_distance(original, reconstructed) - for original, reconstructed in self.iterate_compared_layers() - ) + def iterate_compared_layers(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: + networks = self.target, self.reconstruction + weights_pair = [generate_layer_weights(network) for network in networks] + yield from zip(*weights_pair) + + def calculate_total_distance(self) -> tuple[float, ...]: + pairs = self.iterate_compared_layers() + return tuple(self.calculate_distance(*pair) for pair in pairs) @classmethod - def calculate_weights_distance( - cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor - ) -> float: - distance = torch.nn.functional.l1_loss(original_weights, reconstructed_weights) - return distance.item() + def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float: + return l1_loss(values, other).item() @classmethod def format_evaluation(cls, value: tuple[float, ...], precision: int = 3) -> str: # type: ignore[override] diff --git a/src/revnets/evaluations/weights/mae.py b/src/revnets/evaluations/weights/mae.py index 
b374cc2..2530df5 100644 --- a/src/revnets/evaluations/weights/mae.py +++ b/src/revnets/evaluations/weights/mae.py @@ -1,14 +1,10 @@ import torch +from torch.nn.functional import l1_loss from . import mse class Evaluator(mse.Evaluator): @classmethod - def calculate_weights_distance( - cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor - ) -> float: - distance = torch.nn.functional.l1_loss( - original_weights, reconstructed_weights, reduction="sum" - ) - return distance.item() + def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float: + return l1_loss(values, other, reduction="sum").item() diff --git a/src/revnets/evaluations/weights/max_ae.py b/src/revnets/evaluations/weights/max_ae.py index 75d7aea..22da711 100644 --- a/src/revnets/evaluations/weights/max_ae.py +++ b/src/revnets/evaluations/weights/max_ae.py @@ -1,21 +1,16 @@ import torch +from torch.nn.functional import l1_loss from . import mae class Evaluator(mae.Evaluator): - def calculate_distance(self) -> float: + def calculate_total_distance(self) -> float: return max( - self.calculate_weights_distance(original, reconstruction) + self.calculate_distance(original, reconstruction) for original, reconstruction in self.iterate_compared_layers() ) @classmethod - def calculate_weights_distance( - cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor - ) -> float: - distances = torch.nn.functional.l1_loss( - original_weights, reconstructed_weights, reduction="none" - ) - distance = distances.max() - return distance.item() + def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float: + return l1_loss(values, other, reduction="none").max().item() diff --git a/src/revnets/evaluations/weights/mse.py b/src/revnets/evaluations/weights/mse.py index ed397f1..859d0f0 100644 --- a/src/revnets/evaluations/weights/mse.py +++ b/src/revnets/evaluations/weights/mse.py @@ -3,6 +3,7 @@ from typing import cast import torch +from torch.nn.functional 
import mse_loss from revnets.standardization import Standardizer, align @@ -13,7 +14,7 @@ @dataclass class Evaluator(base.Evaluator): def evaluate(self) -> float | tuple[float, ...] | None: - return self.calculate_distance() if self.standardize_networks() else None + return self.calculate_total_distance() if self.standardize_networks() else None def standardize_networks(self) -> bool: standardized = self.has_same_architecture() @@ -31,24 +32,19 @@ def has_same_architecture(self) -> bool: for original, reconstruction in self.iterate_compared_layers() ) - def calculate_distance(self) -> float | tuple[float, ...]: + def calculate_total_distance(self) -> float | tuple[float, ...]: layer_weights = self.target.state_dict().values() total_size = sum(weights.numel() for weights in layer_weights) total_distance = sum( - self.calculate_weights_distance(original, reconstruction) + self.calculate_distance(original, reconstruction) for original, reconstruction in self.iterate_compared_layers() ) distance = total_distance / total_size return cast(float, distance) @classmethod - def calculate_weights_distance( - cls, original: torch.Tensor, reconstruction: torch.Tensor - ) -> float: - distance = torch.nn.functional.mse_loss( - original, reconstruction, reduction="sum" - ) - return distance.item() + def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float: + return mse_loss(values, other, reduction="sum").item() def iterate_compared_layers(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: yield from zip( diff --git a/src/revnets/evaluations/weights/named_layers_mae.py b/src/revnets/evaluations/weights/named_layers_mae.py index 9506a38..9ed3ab0 100644 --- a/src/revnets/evaluations/weights/named_layers_mae.py +++ b/src/revnets/evaluations/weights/named_layers_mae.py @@ -7,9 +7,9 @@ class Evaluator(layers_mae.Evaluator): - def calculate_distance(self) -> dict[str, float]: # type: ignore[override] + def calculate_total_distance(self) -> dict[str, float]: # 
type: ignore[override] return { - name: self.calculate_weights_distance(original, reconstructed) + name: self.calculate_distance(original, reconstructed) for name, original, reconstructed in self.iterate_named_compared_layers() } diff --git a/src/revnets/main/main.py b/src/revnets/main/main.py index 69eef32..db60368 100644 --- a/src/revnets/main/main.py +++ b/src/revnets/main/main.py @@ -6,6 +6,7 @@ import numpy as np import torch from torch import nn +from torchsummary import summary from revnets import pipelines from revnets.models import Experiment as Config @@ -34,11 +35,18 @@ def run(self) -> None: def run_experiment(self) -> dict[str, Any]: pipeline: Pipeline = extract_module(pipelines, self.config.pipeline).Pipeline() + self.log_number_of_parameters(pipeline) reconstruction = self.create_reconstruction(pipeline) evaluation = evaluations.evaluate(reconstruction, pipeline) evaluation.show() return {"metrics": evaluation.dict(), "config": context.config.dict()} + def log_number_of_parameters(self, pipeline: Pipeline) -> None: + network = pipeline.create_initialized_network() + network = network.to(dtype=torch.float32).to(context.device) + data = pipeline.load_prepared_data() + summary(network, data.input_shape) + def create_reconstruction(self, pipeline: Pipeline) -> nn.Module: module = extract_module(reconstructions, self.config.reconstruction_technique) reconstructor: Reconstructor = module.Reconstructor(pipeline) diff --git a/src/revnets/models/config.py b/src/revnets/models/config.py index 984b129..9d0bec2 100644 --- a/src/revnets/models/config.py +++ b/src/revnets/models/config.py @@ -32,9 +32,9 @@ class Evaluation: @dataclass class Config(SerializationMixin): - sampling_data_size: int = 20000 + sampling_data_size: int = 102400 reconstruction_training: HyperParameters = HyperParameters( - epochs=300, learning_rate=1e-1, batch_size=256 + epochs=300, learning_rate=1e-2, batch_size=256 ) reconstruct_from_checkpoint: bool = False always_train: bool = True @@ 
-47,11 +47,14 @@ class Config(SerializationMixin): target_network_training: HyperParameters = HyperParameters( epochs=100, learning_rate=1.0e-2, batch_size=32 ) + difficult_inputs_training: HyperParameters = HyperParameters( + epochs=1000, learning_rate=1.0e-3 + ) evaluation: Evaluation = field(default_factory=Evaluation) evaluation_batch_size: int = 1000 num_workers: int = 8 - early_stopping_patience: int = 20 + early_stopping_patience: int = 10 n_networks: int = 2 visualization_interval = 10 weight_variance_downscale_factor: float | None = None @@ -68,7 +71,6 @@ class Config(SerializationMixin): validation_ratio: float = 0.1 console_metrics_refresh_interval: float = 0.5 - max_difficult_inputs_epochs: int = 100 limit_batches: int | None = None diff --git a/src/revnets/networks/base.py b/src/revnets/networks/base.py index c26be13..63f3de9 100644 --- a/src/revnets/networks/base.py +++ b/src/revnets/networks/base.py @@ -17,6 +17,7 @@ class NetworkFactory(NamedClass): activation: Activation = field( default_factory=lambda: context.config.target_network_training.activation ) + input_shape: tuple[int, ...] | None = None def create_activation_layer(self) -> Module: activation_layer: Module diff --git a/src/revnets/networks/images/__init__.py b/src/revnets/networks/images/__init__.py index d21d4be..180e980 100644 --- a/src/revnets/networks/images/__init__.py +++ b/src/revnets/networks/images/__init__.py @@ -1 +1 @@ -from . import mediumnet, mininet +from . import cnn, mediumnet, mininet diff --git a/src/revnets/networks/images/cnn/__init__.py b/src/revnets/networks/images/cnn/__init__.py new file mode 100644 index 0000000..11b37fa --- /dev/null +++ b/src/revnets/networks/images/cnn/__init__.py @@ -0,0 +1 @@ +from . 
import lenet, mini diff --git a/src/revnets/networks/images/cnn/lenet.py b/src/revnets/networks/images/cnn/lenet.py new file mode 100644 index 0000000..bdfd0bc --- /dev/null +++ b/src/revnets/networks/images/cnn/lenet.py @@ -0,0 +1,34 @@ +from collections.abc import Iterable +from dataclasses import dataclass + +from torch import nn + +from . import mini + + +@dataclass +class NetworkFactory(mini.NetworkFactory): + hidden_size1: int = 120 + hidden_size2: int = 84 + + def create_layers(self) -> Iterable[nn.Module]: + yield from ( + # 28 x 28 x 1 + nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1), + nn.Tanh(), + # 24 x 24 x 6 + nn.AvgPool2d(kernel_size=2, stride=2), + # 12 x 12 x 6 + nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1), + nn.Tanh(), + # 8 x 8 x 16 + nn.AvgPool2d(kernel_size=2, stride=2), + # 4 x 4 x 16 + nn.Flatten(), + # 256 + nn.Linear(in_features=256, out_features=self.hidden_size1), + nn.Tanh(), + nn.Linear(in_features=self.hidden_size1, out_features=self.hidden_size2), + nn.Tanh(), + nn.Linear(in_features=self.hidden_size2, out_features=self.output_size), + ) diff --git a/src/revnets/networks/images/cnn/mini.py b/src/revnets/networks/images/cnn/mini.py new file mode 100644 index 0000000..8f1723a --- /dev/null +++ b/src/revnets/networks/images/cnn/mini.py @@ -0,0 +1,25 @@ +from collections.abc import Iterable +from dataclasses import dataclass + +from torch import nn + +from .. import mediumnet + + +@dataclass +class NetworkFactory(mediumnet.NetworkFactory): + input_shape: tuple[int, ...] 
= 1, 28, 28 + + def create_layers(self) -> Iterable[nn.Module]: + yield from ( + # 28 x 28 x 1 + nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=8), + nn.LeakyReLU(), + # 6 x 6 x 6 + # nn.Conv2d(in_channels=6, out_channels=6, kernel_size=5, stride=2), + # nn.Tanh(), + # 4 x 4 x 6 + # 96 + nn.Conv2d(in_channels=4, out_channels=10, kernel_size=3), + nn.Flatten(), + ) diff --git a/src/revnets/pipelines/images/__init__.py b/src/revnets/pipelines/images/__init__.py index 288514e..f824798 100644 --- a/src/revnets/pipelines/images/__init__.py +++ b/src/revnets/pipelines/images/__init__.py @@ -1,2 +1,3 @@ +from . import cnn from .mediumnet import mediumnet, mediumnet_small from .mininet import mininet_100, mininet_128, mininet_200, mininet_small diff --git a/src/revnets/pipelines/images/cnn.py b/src/revnets/pipelines/images/cnn.py new file mode 100644 index 0000000..6ac9647 --- /dev/null +++ b/src/revnets/pipelines/images/cnn.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +from revnets import networks +from revnets.networks import NetworkFactory + +from .mininet import mininet_small + + +@dataclass +class Pipeline(mininet_small.Pipeline): + max_epochs: int = 10 + + @classmethod + def create_network_factory(cls) -> NetworkFactory: + return networks.images.cnn.mini.NetworkFactory() diff --git a/src/revnets/pipelines/train.py b/src/revnets/pipelines/train.py index e87f7be..18e9bb4 100644 --- a/src/revnets/pipelines/train.py +++ b/src/revnets/pipelines/train.py @@ -1,5 +1,5 @@ from abc import ABC -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property import torch @@ -18,6 +18,10 @@ @dataclass class Pipeline(base.Pipeline, ABC): + max_epochs: int = field( + default_factory=lambda: context.config.target_network_training.epochs + ) + def create_initialized_network(self) -> Sequential: return self.network_factory.create_network(seed=context.config.experiment.seed) @@ -53,9 +57,8 @@ def train(self, 
network: torch.nn.Module) -> None: data = self.load_data() self.run_training(trainable_network, data) - @classmethod - def run_training(cls, network: LightningModule, data: DataModule) -> None: - trainer = Trainer(max_epochs=context.config.target_network_training.epochs) + def run_training(self, network: LightningModule, data: DataModule) -> None: + trainer = Trainer(max_epochs=self.max_epochs) trainer.fit(network, data) trainer.test(network, data) diff --git a/src/revnets/reconstructions/queries/__init__.py b/src/revnets/reconstructions/queries/__init__.py index 0923121..d26dc3b 100644 --- a/src/revnets/reconstructions/queries/__init__.py +++ b/src/revnets/reconstructions/queries/__init__.py @@ -1,2 +1,8 @@ -from . import arbitrary_correlated_features, correlated_features, iterative, random +from . import ( + arbitrary_correlated_features, + correlated_features, + iterative, + random, + target_train_data, +) from .base import Reconstructor diff --git a/src/revnets/reconstructions/queries/base.py b/src/revnets/reconstructions/queries/base.py index 14a277a..abd0dc8 100644 --- a/src/revnets/reconstructions/queries/base.py +++ b/src/revnets/reconstructions/queries/base.py @@ -23,6 +23,9 @@ @dataclass class Reconstructor(base.Reconstructor): num_samples: int = field(default_factory=lambda: context.config.sampling_data_size) + max_epochs: int = field( + default_factory=lambda: context.config.reconstruction_training.epochs + ) @cached_property def trained_weights_path(self) -> Path: @@ -40,14 +43,15 @@ def reconstruct_weights(self) -> None: self.save_weights() self.load_weights() - def create_trainer(self) -> Trainer: + def create_trainer(self, max_epochs: int | None = None) -> Trainer: patience = context.config.early_stopping_patience callbacks = ( EarlyStopping("train l1 loss", patience=patience, verbose=True), MAECalculator(self.reconstruction, self.pipeline), LearningRateScheduler(), ) - max_epochs = context.config.reconstruction_training.epochs + if max_epochs is 
None: + max_epochs = self.max_epochs return Trainer(callbacks=callbacks, max_epochs=max_epochs) # type: ignore[arg-type] def create_train_network(self) -> Network: diff --git a/src/revnets/reconstructions/queries/correlated_features.py b/src/revnets/reconstructions/queries/correlated_features.py index 3f57f07..9de24ee 100644 --- a/src/revnets/reconstructions/queries/correlated_features.py +++ b/src/revnets/reconstructions/queries/correlated_features.py @@ -21,7 +21,7 @@ def create_random_inputs(self, shape: Sequence[int]) -> torch.Tensor: distribution = MultivariateNormal(means, covariance_matrix) # type: ignore[no-untyped-call] sample_shape = torch.Size((shape[0],)) # same mean, variance, and covariance as the training data - return distribution.sample(sample_shape) + return distribution.sample(sample_shape).reshape(shape) def infer_distribution_parameters(self) -> tuple[torch.Tensor, torch.Tensor]: train_inputs = self.extract_flattened_train_inputs() diff --git a/src/revnets/reconstructions/queries/iterative/difficult_inputs.py b/src/revnets/reconstructions/queries/iterative/difficult_inputs.py index d5d54d5..82a1a22 100644 --- a/src/revnets/reconstructions/queries/iterative/difficult_inputs.py +++ b/src/revnets/reconstructions/queries/iterative/difficult_inputs.py @@ -18,12 +18,15 @@ class InputNetwork(LightningModule): def __init__( self, - shape: tuple[int, int], + shape: tuple[int, ...], reconstructions: list[torch.nn.Sequential], - learning_rate: float = 0.01, - verbose: bool = False, + learning_rate: float | None = None, + verbose: bool = True, ) -> None: super().__init__() + self.shape = shape + if learning_rate is None: + learning_rate = context.config.difficult_inputs_training.learning_rate self.learning_rate = learning_rate self.inputs_embedding = self.create_input_embeddings(shape) self.reconstructions = torch.nn.ModuleList(reconstructions) @@ -34,15 +37,16 @@ def on_train_start(self) -> None: print("\nAverage pairwise distances: ", end="\n\t") # 
pragma: nocover @classmethod - def create_input_embeddings(cls, shape: tuple[int, int]) -> torch.nn.Embedding: - embeddings = torch.nn.Embedding(*shape) + def create_input_embeddings(cls, shape: tuple[int, ...]) -> torch.nn.Embedding: + feature_shape = math.prod(shape[1:]) + embeddings = torch.nn.Embedding(shape[0], feature_shape) torch.nn.init.normal_(embeddings.weight) return embeddings def forward(self, _: Any) -> torch.Tensor: outputs = [] for reconstruction in self.reconstructions: - output = reconstruction(self.inputs_embedding.weight) + output = reconstruction(self.inputs_embedding.weight.reshape(self.shape)) reconstruction.zero_grad() outputs.append(output) return torch.stack(outputs) @@ -65,7 +69,7 @@ def configure_optimizers(self) -> Optimizer: self.inputs_embedding.parameters(), lr=self.learning_rate ) - def get_optimized_inputs(self) -> torch.Tensor: + def extract_optimized_inputs(self) -> torch.Tensor: return self.inputs_embedding.weight.detach() @@ -88,26 +92,25 @@ def __post_init__(self) -> None: for seed in range(self.n_networks) ] - @property - def feature_size(self) -> int: - return math.prod(self.input_shape) + def create_queries(self, num_samples: int) -> torch.Tensor: + return self.create_difficult_samples() def create_difficult_samples(self) -> torch.Tensor: - shape = (self.num_samples, self.feature_size) + shape = (self.num_samples, *self.input_shape) network = InputNetwork(shape, self.reconstructions) self.fit_inputs_network(network) - return network.get_optimized_inputs() + return network.extract_optimized_inputs() @classmethod def fit_inputs_network(cls, network: InputNetwork) -> None: - max_epochs = context.config.max_difficult_inputs_epochs - trainer = Trainer(max_epochs=max_epochs, log_every_n_steps=1) + epochs = context.config.difficult_inputs_training.epochs + trainer = Trainer(max_epochs=epochs, log_every_n_steps=1) dataset = EmptyDataset() dataloader = DataLoader(dataset) trainer.fit(network, dataloader) def run_round(self, data: 
DataModule) -> None: - trainer = self.create_trainer() + trainer = self.create_trainer(max_epochs=1) networks = [Network(reconstruction) for reconstruction in self.reconstructions] for network in networks: trainer.fit(network, data) diff --git a/src/revnets/reconstructions/queries/iterative/difficult_train_inputs.py b/src/revnets/reconstructions/queries/iterative/difficult_train_inputs.py index a65ff41..feb5881 100644 --- a/src/revnets/reconstructions/queries/iterative/difficult_train_inputs.py +++ b/src/revnets/reconstructions/queries/iterative/difficult_train_inputs.py @@ -22,12 +22,23 @@ def create_difficult_samples(self) -> torch.Tensor: return recombined_inputs + self.noise_factor * noise def recombine(self, inputs: torch.Tensor) -> torch.Tensor: - new_samples_shape = self.num_samples, inputs.shape[-1] - n_inputs = len(inputs) - new_samples = np.random.choice(range(n_inputs), size=new_samples_shape) + feature_dimensions = inputs.shape[1:] + untyped_feature_dimension = np.prod(feature_dimensions) + feature_dimension = cast(int, untyped_feature_dimension) + flat_inputs = inputs.reshape((-1, feature_dimension)) + recombined_flat_inputs = self.recombine_flat(flat_inputs) + shape = -1, *feature_dimensions + return recombined_flat_inputs.reshape(shape) + + def recombine_flat(self, inputs: torch.Tensor) -> torch.Tensor: + number_of_features = inputs.shape[-1] + new_samples_shape = self.num_samples, number_of_features + number_of_inputs = len(inputs) + choices = range(number_of_inputs) + new_samples = np.random.choice(choices, size=new_samples_shape) # each feature value in a new sample corresponds with a feature value # in the corresponding feature of one of the inputs - return inputs[new_samples, np.arange(number_of_features)] def extract_difficult_inputs(self) -> torch.Tensor: data = self.pipeline.load_prepared_data() diff --git a/src/revnets/reconstructions/queries/random.py 
b/src/revnets/reconstructions/queries/random.py index 62f712d..ee18558 100644 --- a/src/revnets/reconstructions/queries/random.py +++ b/src/revnets/reconstructions/queries/random.py @@ -19,8 +19,7 @@ def create_queries(self, num_samples: int) -> torch.Tensor: @property def input_shape(self) -> tuple[int, ...]: dataset = self.pipeline.load_prepared_data() - shape = dataset.train_validation[0][0].shape - return cast(tuple[int, ...], shape) + return dataset.input_shape def create_random_inputs(self, shape: Sequence[int]) -> torch.Tensor: train_inputs = self.extract_train_inputs() diff --git a/src/revnets/reconstructions/queries/target_train_data.py b/src/revnets/reconstructions/queries/target_train_data.py new file mode 100644 index 0000000..2957490 --- /dev/null +++ b/src/revnets/reconstructions/queries/target_train_data.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +import torch +from torch.utils.data import DataLoader + +from .iterative import difficult_train_inputs + + +@dataclass +class Reconstructor(difficult_train_inputs.Reconstructor): + def create_queries(self, num_samples: int) -> torch.Tensor: + data = self.pipeline.load_prepared_data().train_validation + batch_size = len(data) # type: ignore[arg-type] + dataloader = DataLoader(data, batch_size, shuffle=False) + inputs, _ = next(iter(dataloader)) + recombined_inputs = self.recombine(inputs) + return ( + recombined_inputs + self.create_random_inputs(recombined_inputs.shape) / 100 + ) diff --git a/src/revnets/standardization/__init__.py b/src/revnets/standardization/__init__.py index 6b7274b..d83b250 100644 --- a/src/revnets/standardization/__init__.py +++ b/src/revnets/standardization/__init__.py @@ -1,4 +1,4 @@ from . 
import network, order, scale from .align import align, calculate_optimal_order from .network import Standardizer, generate_internal_neurons, generate_layers -from .utils import extract_layer_weights, extract_linear_layer_weights +from .utils import extract_parameters, extract_weights diff --git a/src/revnets/standardization/align.py b/src/revnets/standardization/align.py index 47fd35c..aa722bd 100644 --- a/src/revnets/standardization/align.py +++ b/src/revnets/standardization/align.py @@ -21,13 +21,13 @@ def align(model: Module, target: Module) -> None: def align_internal_neurons(neurons: InternalNeurons, target: InternalNeurons) -> None: sort_indices = calculate_optimal_order(neurons.incoming, target.incoming) - order.permute_output_weights(neurons.incoming, sort_indices) - order.permute_input_weights(neurons.outgoing, sort_indices) + order.permute_incoming_weights(neurons.incoming, sort_indices) + order.permute_outgoing_weights(neurons.outgoing, sort_indices) def calculate_optimal_order(layer: Module, target: Module) -> torch.Tensor: - weights = order.extract_linear_layer_weights(layer) - target_weights = order.extract_linear_layer_weights(target) + weights = order.extract_weights(layer) + target_weights = order.extract_weights(target) distances = torch.cdist(target_weights, weights, p=1).numpy() indices = linear_sum_assignment(distances)[1] return torch.from_numpy(indices) diff --git a/src/revnets/standardization/network.py b/src/revnets/standardization/network.py index 30ef5b1..dec7912 100644 --- a/src/revnets/standardization/network.py +++ b/src/revnets/standardization/network.py @@ -3,19 +3,19 @@ from functools import cached_property from typing import TypeVar, cast -from torch.nn import Flatten, Module +from torch import nn from revnets.models import InternalNeurons from . 
import order, scale -from .utils import extract_linear_layer_weights +from .utils import extract_weights T = TypeVar("T") @dataclass class Standardizer: - model: Module + model: nn.Module optimize_mae: bool = False def run(self) -> None: @@ -41,7 +41,7 @@ def apply_optimize_mae(self) -> None: scale.Standardizer(neurons).run() def calculate_average_scale_per_layer(self) -> float: - weights = extract_linear_layer_weights(self.internal_neurons[-1].outgoing) + weights = extract_weights(self.internal_neurons[-1].outgoing) last_neuron_scales = weights.norm(dim=1, p=2) last_neuron_scale = sum(last_neuron_scales) / len(last_neuron_scales) num_internal_layers = len(self.internal_neurons) @@ -54,7 +54,7 @@ def internal_neurons(self) -> list[InternalNeurons]: return list(neurons) -def generate_internal_neurons(model: Module) -> Iterator[InternalNeurons]: +def generate_internal_neurons(model: nn.Module) -> Iterator[InternalNeurons]: layers = generate_layers(model) layers_list = list(layers) for triplet in generate_triplets(layers_list): @@ -65,7 +65,11 @@ def generate_triplets(items: list[T]) -> Iterator[tuple[T, T, T]]: yield from zip(items[::2], items[1::2], items[2::2]) -def generate_layers(model: Module) -> Iterator[Module]: +# TODO: MaxPool destroys sign isomorphism for tanh +skip_layer_types = nn.Flatten, nn.MaxPool1d, nn.MaxPool2d, nn.AvgPool1d, nn.AvgPool2d + + +def generate_layers(model: nn.Module) -> Iterator[nn.Module]: """ :return: all root layers (the deepest level) in order of feature propagation """ @@ -74,5 +78,5 @@ def generate_layers(model: Module) -> Iterator[Module]: for child in children: yield from generate_layers(child) else: - if not isinstance(model, Flatten): + if not isinstance(model, skip_layer_types): yield model diff --git a/src/revnets/standardization/order.py b/src/revnets/standardization/order.py index 9f41837..607ae77 100644 --- a/src/revnets/standardization/order.py +++ b/src/revnets/standardization/order.py @@ -5,7 +5,7 @@ from 
revnets.models import InternalNeurons -from .utils import extract_linear_layer_weights +from .utils import extract_parameters, extract_weights @dataclass @@ -14,31 +14,28 @@ class Standardizer: def run(self) -> None: sort_indices = calculate_sort_order(self.neurons.incoming) - permute_output_weights(self.neurons.incoming, sort_indices) - permute_input_weights(self.neurons.outgoing, sort_indices) + permute_incoming_weights(self.neurons.incoming, sort_indices) + permute_outgoing_weights(self.neurons.outgoing, sort_indices) def calculate_sort_order(layer: nn.Module) -> torch.Tensor: - weights = extract_linear_layer_weights(layer) + weights = extract_weights(layer) p = 1 # use l1-norm because l2-norm is already standardized sort_values = weights.norm(dim=1, p=p) return torch.sort(sort_values)[1] -def permute_input_weights(layer: nn.Module, sort_indices: torch.Tensor) -> None: - length = len(sort_indices) - for param in layer.parameters(): - shape = param.data.shape - if len(shape) == 2 and shape[1] == length: - param.data = param.data[:, sort_indices] +def permute_incoming_weights(layer: nn.Module, sort_indices: torch.Tensor) -> None: + parameters = extract_parameters(layer) + parameters.weight.data = parameters.weight.data[sort_indices] + if parameters.bias is not None: + parameters.bias.data = parameters.bias.data[sort_indices] -def permute_output_weights(layer: nn.Module, sort_indices: torch.Tensor) -> None: - length = len(sort_indices) - for param in layer.parameters(): - shape = param.shape - dims = len(shape) - if dims in (1, 2) and shape[0] == length: - param.data = ( - param.data[sort_indices] if dims == 1 else param.data[sort_indices, :] - ) +def permute_outgoing_weights(layer: nn.Module, sort_indices: torch.Tensor) -> None: + parameters = extract_parameters(layer) + # take into account that flatten layers cause outgoing weights with altered shapes + shape = parameters.weight.data.shape[0], sort_indices.shape[0], -1 + data = parameters.weight.data.view(shape) 
+ data = torch.index_select(data, 1, sort_indices) + parameters.weight.data = data.reshape(parameters.weight.shape) diff --git a/src/revnets/standardization/scale.py b/src/revnets/standardization/scale.py index 2e8b0a5..1bf04bf 100644 --- a/src/revnets/standardization/scale.py +++ b/src/revnets/standardization/scale.py @@ -6,7 +6,7 @@ from revnets.models import InternalNeurons -from .utils import extract_linear_layer_weights +from .utils import extract_parameters, extract_weights @dataclass @@ -28,7 +28,7 @@ def standardize_scale(self) -> None: rescale_outgoing_weights(self.neurons.outgoing, scale_factors) def calculate_scale_factors(self, layer: Module) -> torch.Tensor: - weights = extract_linear_layer_weights(layer) + weights = extract_weights(layer) scale_factors = ( torch.sign(weights.sum(dim=1)) if self.neurons.has_sign_isomorphism @@ -38,12 +38,24 @@ def calculate_scale_factors(self, layer: Module) -> torch.Tensor: def rescale_incoming_weights(layer: Module, scales: torch.Tensor) -> None: - for param in layer.parameters(): - multiplier = scales if len(param.data.shape) == 1 else scales.reshape(-1, 1) - param.data *= multiplier + parameters = extract_parameters(layer) + parameters.weight.data *= broadcast(scales, parameters.weight.data) + if parameters.bias is not None: + parameters.bias.data *= scales + + +def broadcast( + values: torch.Tensor, target: torch.Tensor, dimension: int = 0 +) -> torch.Tensor: + shape = [1] * target.dim() + shape[dimension] = -1 + return values.view(*shape) def rescale_outgoing_weights(layer: Module, scales: torch.Tensor) -> None: - for param in layer.parameters(): - if len(param.shape) == 2: - param.data *= scales + parameters = extract_parameters(layer) + # take into account that flatten layers cause outgoing weights with altered shapes + shape = parameters.weight.data.shape[0], scales.shape[0], -1 + data = parameters.weight.data.view(shape) + data *= broadcast(scales, data, dimension=1) + parameters.weight.data = 
data.reshape(parameters.weight.shape) diff --git a/src/revnets/standardization/utils.py b/src/revnets/standardization/utils.py index 2007b27..32b5851 100644 --- a/src/revnets/standardization/utils.py +++ b/src/revnets/standardization/utils.py @@ -1,24 +1,38 @@ +from dataclasses import dataclass + import torch -from torch.nn import Linear, Module +from torch.nn import Module, Parameter -def extract_layer_weights( - layer: Module, device: torch.device | None = None -) -> torch.Tensor | None: - if isinstance(layer, Linear): - weights = extract_linear_layer_weights(layer, device) - else: - weights = None - return weights +@dataclass +class Parameters: + weight: Parameter + bias: Parameter | None + + @property + def number_of_outputs(self) -> int: + return self.weight.shape[0] + +def extract_parameters(layer: Module) -> Parameters: + parameters = layer.parameters() + weight = next(parameters) + bias = next(parameters, None) + return Parameters(weight, bias) -def extract_linear_layer_weights( - layer: Module, device: torch.device | None = None -) -> torch.Tensor: + +def extract_weights(layer: Module) -> torch.Tensor: with torch.no_grad(): - connection_weights, bias_weights = layer.parameters() - weights_tuple = (connection_weights, bias_weights.reshape(-1, 1)) - weights = torch.hstack(weights_tuple) - if device is not None: - weights = weights.to(device) + return _extract_weights(layer) + + +def _extract_weights(layer: Module) -> torch.Tensor: + parameters = extract_parameters(layer) + # flatten incoming weights for each out layer output (Conv, ..) 
+ shape = parameters.number_of_outputs, -1 + weights = parameters.weight.reshape(shape) + if parameters.bias is not None: + bias = parameters.bias.reshape(shape) + combined_weights = (weights, bias) + weights = torch.hstack(combined_weights) return weights diff --git a/src/revnets/training/network.py b/src/revnets/training/network.py index 29813c6..488be46 100644 --- a/src/revnets/training/network.py +++ b/src/revnets/training/network.py @@ -73,7 +73,7 @@ def log( # type: ignore[override] value: float, sync_dist: bool = True, on_epoch: bool = True, - on_step: bool = False, + on_step: bool = True, **kwargs: Any, ) -> None: if self.do_log: diff --git a/src/revnets/training/reconstructions/callbacks/learning_rate.py b/src/revnets/training/reconstructions/callbacks/learning_rate.py index 01deed7..ba8c6b4 100644 --- a/src/revnets/training/reconstructions/callbacks/learning_rate.py +++ b/src/revnets/training/reconstructions/callbacks/learning_rate.py @@ -35,7 +35,9 @@ class LearningRateScheduler(Callback): losses: list[float] = field(default_factory=list) def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - pl_module.log("learning rate", pl_module.learning_rate, prog_bar=True) + pl_module.log( + "learning rate", pl_module.learning_rate, prog_bar=True, on_step=False + ) loss = trainer.callback_metrics["train l1 loss"].item() self.losses.append(loss) self.check_learning_rate(pl_module, loss) diff --git a/src/revnets/training/reconstructions/callbacks/mae.py b/src/revnets/training/reconstructions/callbacks/mae.py index fe4dfcd..d37acfe 100644 --- a/src/revnets/training/reconstructions/callbacks/mae.py +++ b/src/revnets/training/reconstructions/callbacks/mae.py @@ -25,11 +25,11 @@ def __post_init__(self) -> None: def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: if trainer.current_epoch % self.calculate_interval == 0: mae = self.calculate_mae() - pl_module.log(self.name, mae, prog_bar=True) + 
pl_module.log(self.name, mae, prog_bar=True, on_step=False) def calculate_mae(self) -> float: state_dict = self.reconstruction.state_dict() Standardizer(self.reconstruction).run() - mae = self.evaluator.calculate_distance() + mae = self.evaluator.calculate_total_distance() self.reconstruction.load_state_dict(state_dict) return cast(float, mae) diff --git a/src/revnets/training/reconstructions/network.py b/src/revnets/training/reconstructions/network.py index e21f163..7dd7776 100644 --- a/src/revnets/training/reconstructions/network.py +++ b/src/revnets/training/reconstructions/network.py @@ -23,7 +23,7 @@ def update_learning_rate(self, learning_rate: float) -> None: try: optimizer = self.optimizers() # if training is no longer running, optimizer is an empty list - if optimizer: + if optimizer: # pragma: nocover typed_optimizer = cast(LightningOptimizer, optimizer) self.update_optimizer(typed_optimizer, learning_rate) except RuntimeError: # pragma: nocover diff --git a/tests/conftest.py b/tests/conftest.py index 174b9c0..c164f87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ def test_context(context: Context, mocked_assets_path: None) -> Iterator[Context context.loaders.config.value = Config( target_network_training=hyperparameters, reconstruction_training=hyperparameters, - max_difficult_inputs_epochs=1, + difficult_inputs_training=hyperparameters, evaluation=evaluation, limit_batches=5, ) diff --git a/tests/evaluations/test_standardize.py b/tests/evaluations/test_standardize.py index 43b7c07..40b9dc8 100644 --- a/tests/evaluations/test_standardize.py +++ b/tests/evaluations/test_standardize.py @@ -16,6 +16,8 @@ networks.mediumnet, networks.images.mininet, networks.images.mediumnet, + networks.images.cnn.mini, + networks.images.cnn.lenet, ) activations = ( Activation.leaky_relu, @@ -76,6 +78,5 @@ def test_second_standardize_no_effect( @pytest.mark.parametrize("network_module", network_modules) @pytest.mark.parametrize("activation", 
activations) def test_optimize_mae(network_module: ModuleType, activation: Activation) -> None: - standardization_type = Standardization.standardize - tester = Verifier(network_module, activation, standardization_type) + tester = Verifier(network_module, activation, Standardization.standardize) tester.test_optimize_mae() diff --git a/tests/evaluations/test_utils.py b/tests/evaluations/test_utils.py new file mode 100644 index 0000000..1b7f3de --- /dev/null +++ b/tests/evaluations/test_utils.py @@ -0,0 +1,37 @@ +import pytest +from hypothesis import given, strategies +from revnets.standardization.utils import extract_weights +from torch import nn + +from tests.evaluations.test_standardize import activations + + +@given( + in_features=strategies.integers(min_value=1, max_value=10), + out_features=strategies.integers(min_value=1, max_value=10), +) +def test_linear_extract_weights(in_features: int, out_features: int) -> None: + layer = nn.Linear(in_features=in_features, out_features=out_features) + weights = extract_weights(layer) + assert weights.shape == (out_features, in_features + 1) + + +@given( + in_channels=strategies.integers(min_value=1, max_value=10), + out_channels=strategies.integers(min_value=1, max_value=10), + kernel_size=strategies.integers(min_value=1, max_value=10), +) +def test_convolutional_extract_weights( + in_channels: int, out_channels: int, kernel_size: int +) -> None: + layer = nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size + ) + weights = extract_weights(layer) + assert weights.shape == (out_channels, in_channels * kernel_size**2 + 1) + + +@pytest.mark.parametrize("activation", activations) +def test_activation_extract_weights(activation: nn.Module) -> None: + with pytest.raises(AttributeError): + extract_weights(activation) diff --git a/tests/evaluations/verifier.py b/tests/evaluations/verifier.py index 083909d..a76975a 100644 --- a/tests/evaluations/verifier.py +++ b/tests/evaluations/verifier.py @@ 
-9,6 +9,7 @@ from revnets import standardization from revnets.context import context from revnets.models import Activation, InternalNeurons +from revnets.networks import NetworkFactory from revnets.standardization import Standardizer, align, generate_internal_neurons from torch.nn import Module, Sequential @@ -32,6 +33,12 @@ def __post_init__(self) -> None: def target(self) -> Sequential: return self.create_network() + @property + def network_factory(self) -> NetworkFactory: + Factory = self.network_module.NetworkFactory + factory = Factory(activation=self.activation) + return cast(NetworkFactory, factory) + def test_standardized_form(self) -> None: self.apply_transformation() self.verify_form() @@ -79,19 +86,21 @@ def verify_aligned_form(self) -> None: verify_aligned(neurons, target_neurons) def create_network(self) -> Sequential: - Factory = self.network_module.NetworkFactory - factory = Factory(activation=self.activation) - network = factory.create_network() - return cast(Sequential, network) + return self.network_factory.create_network() def create_network_inputs(self) -> torch.Tensor: - size = 1, self.extract_input_size() + feature_shape = ( + self.extract_feature_shape() + if self.network_factory.input_shape is None + else self.network_factory.input_shape + ) + size = 1, *feature_shape return torch.rand(size, dtype=context.dtype) * 20 - 10 - def extract_input_size(self) -> int: + def extract_feature_shape(self) -> tuple[int, ...]: input_layer = next(self.extract_layers()) - size = input_layer.weight.shape[1] - return cast(int, size) + shape = input_layer.weight.shape[1:] + return cast(tuple[int, ...], shape) def extract_layers(self) -> Iterator[Module]: for layer in self.network.children(): @@ -120,7 +129,7 @@ def verify_scale_standardized(neurons: InternalNeurons) -> None: def verify_order_standardized(neurons: InternalNeurons) -> None: - weights = standardization.extract_linear_layer_weights(neurons.incoming) + weights = 
standardization.extract_weights(neurons.incoming) incoming_weights = weights.norm(dim=1, p=1) sorted_indices = incoming_weights[:-1] <= incoming_weights[1:] is_sorted = torch.all(sorted_indices) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 4720e07..16dae80 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -9,6 +9,7 @@ pipelines.mediumnet.mediumnet, pipelines.images.mininet_small, pipelines.images.mediumnet_small, + pipelines.images.cnn, ) all_pipeline_modules = ( diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index b0df07c..bb69c65 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -1,9 +1,13 @@ from types import ModuleType import pytest +from pytorch_lightning.core.optimizer import LightningOptimizer from revnets import reconstructions from revnets.pipelines.mininet import Pipeline from revnets.reconstructions import Reconstructor, queries +from revnets.training.reconstructions import Network +from revnets.training.reconstructions.callbacks import LearningRateScheduler +from torch import nn reconstruction_modules = ( reconstructions.empty, @@ -12,6 +16,7 @@ queries.random, queries.correlated_features, queries.arbitrary_correlated_features, + queries.target_train_data, queries.iterative.difficult_inputs, queries.iterative.difficult_train_inputs, ) @@ -22,3 +27,12 @@ def test_reconstructions(reconstruction_module: ModuleType, test_context: None) pipeline = Pipeline() reconstructor: Reconstructor = reconstruction_module.Reconstructor(pipeline) reconstructor.create_reconstruction() + + +def test_learning_rate_scheduler() -> None: + model = nn.Linear(in_features=1, out_features=1) + network = Network(model) + LearningRateScheduler.check_learning_rate(network, 0) + optimizer = network.configure_optimizers() + lightning_optimizer = LightningOptimizer(optimizer) + network.update_optimizer(lightning_optimizer, 0)