diff --git a/src/revnets/context/context.py b/src/revnets/context/context.py
index f3e9027..0757889 100644
--- a/src/revnets/context/context.py
+++ b/src/revnets/context/context.py
@@ -2,9 +2,8 @@
 
 import torch
 from package_utils.context import Context as Context_
-from torch import nn
 
-from ..models import Activation, Config, HyperParameters, Options, Path
+from ..models import Config, HyperParameters, Options, Path
 
 
 class Context(Context_[Options, Config, None]):
@@ -50,23 +49,21 @@ def log_path(self) -> Path:
     def log_path_str(self) -> str:
         return str(self.log_path)
 
-    @property
-    def activation_layer(self) -> nn.Module:
-        activation = context.config.target_network_training.activation
-        activation_layer: nn.Module
-        match activation:
-            case Activation.leaky_relu:
-                activation_layer = nn.LeakyReLU()
-            case Activation.relu:
-                activation_layer = nn.ReLU()
-            case Activation.tanh:
-                activation_layer = nn.Tanh()
-        return activation_layer
-
     @cached_property
     def device(self) -> torch.device:
         name = "cuda" if torch.cuda.is_available() else "cpu"
         return torch.device(name)
 
+    @property
+    def dtype(self) -> torch.dtype:
+        match self.config.precision:
+            case 32:
+                dtype = torch.float32
+            case 64:
+                dtype = torch.float64
+            case _:
+                raise ValueError(f"Unsupported precision {self.config.precision}")
+        return dtype
+
 
 context = Context(Options, Config, None)
diff --git a/src/revnets/evaluations/weights/standardize/network.py b/src/revnets/evaluations/weights/standardize/network.py
index 0a1b294..ac71d9a 100644
--- a/src/revnets/evaluations/weights/standardize/network.py
+++ b/src/revnets/evaluations/weights/standardize/network.py
@@ -3,7 +3,7 @@
 from functools import cached_property
 from typing import TypeVar, cast
 
-from torch.nn import Module
+from torch.nn import Flatten, Module
 
 from revnets.models import InternalNeurons
 
@@ -57,7 +57,7 @@ def generate_internal_neurons(model: Module) -> Iterator[InternalNeurons]:
 
 
 def generate_triplets(items: list[T]) -> Iterator[tuple[T, T, T]]:
-    yield from zip(items, items[1:], items[2:])
+    yield from zip(items[::2], items[1::2], items[2::2])
 
 
 def generate_layers(model: Module) -> Iterator[Module]:
@@ -69,4 +69,5 @@ def generate_layers(model: Module) -> Iterator[Module]:
         for child in children:
             yield from generate_layers(child)
     else:
-        yield model
+        if not isinstance(model, Flatten):
+            yield model
diff --git a/src/revnets/evaluations/weights/standardize/order.py b/src/revnets/evaluations/weights/standardize/order.py
index 327aa6d..b8f3be4 100644
--- a/src/revnets/evaluations/weights/standardize/order.py
+++ b/src/revnets/evaluations/weights/standardize/order.py
@@ -13,16 +13,16 @@ class Standardizer:
     neurons: InternalNeurons
 
     def run(self) -> None:
-        sort_indices = calculate_sort_order(self.neurons.incoming)
+        sort_indices = calculate_sort_indices(self.neurons.incoming)
         permute_output_weights(self.neurons.incoming, sort_indices)
         permute_input_weights(self.neurons.outgoing, sort_indices)
 
 
-def calculate_sort_order(layer: nn.Module) -> torch.Tensor:
+def calculate_sort_indices(layer: nn.Module) -> torch.Tensor:
     weights = extract_linear_layer_weights(layer)
     p = 1  # use l1-norm because l2-norm is already standardized
-    total_output_weights = weights.norm(dim=1, p=p)
-    return torch.sort(total_output_weights)[1]
+    sort_values = weights.norm(dim=1, p=p)
+    return torch.sort(sort_values)[1]
 
 
 def permute_input_weights(layer: nn.Module, sort_indices: torch.Tensor) -> None:
diff --git a/src/revnets/models/config.py b/src/revnets/models/config.py
index f1b750d..b1dd1c4 100644
--- a/src/revnets/models/config.py
+++ b/src/revnets/models/config.py
@@ -18,7 +18,7 @@ class HyperParameters:
     epochs: int = 1
     bias_learning_rate: float | None = None
     batch_size: int = 1
-    activation: Enum = Activation.leaky_relu
+    activation: Activation = Activation.leaky_relu
 
 
 @dataclass
diff --git a/src/revnets/networks/base.py b/src/revnets/networks/base.py
index 83ddc8a..da98826 100644
--- a/src/revnets/networks/base.py
+++ b/src/revnets/networks/base.py
@@ -2,7 +2,10 @@
 from dataclasses import dataclass
 
 import torch
-from torch.nn import Sequential
+from torch import nn
+from torch.nn import Module, Sequential
+
+from revnets.models import Activation
 
 from ..context import context
 from ..utils import NamedClass
@@ -10,13 +13,24 @@
 
 @dataclass(frozen=True)
 class NetworkFactory(NamedClass):
-    activation_layer: torch.nn.Module = context.activation_layer
     output_size: int = 10
+    activation: Activation = context.config.target_network_training.activation
+
+    def create_activation_layer(self) -> Module:
+        activation_layer: Module
+        match self.activation:
+            case Activation.leaky_relu:
+                activation_layer = nn.LeakyReLU()
+            case Activation.relu:
+                activation_layer = nn.ReLU()
+            case Activation.tanh:
+                activation_layer = nn.Tanh()
+        return activation_layer
 
     def create_network(self, seed: int | None = None) -> Sequential:
         if seed is not None:
             torch.manual_seed(seed)
-        layers = self.create_layers()
+        layers = [layer.to(dtype=context.dtype) for layer in self.create_layers()]
         return Sequential(*layers)
 
     def create_layers(self) -> Iterable[torch.nn.Module]:
diff --git a/src/revnets/networks/mediumnet.py b/src/revnets/networks/mediumnet.py
index 46c1c2b..b3bba87 100644
--- a/src/revnets/networks/mediumnet.py
+++ b/src/revnets/networks/mediumnet.py
@@ -16,8 +16,8 @@ class NetworkFactory(base.NetworkFactory):
     def create_layers(self) -> Iterable[nn.Module]:
         return (
             nn.Linear(self.input_size, self.hidden_size1),
-            self.activation_layer,
+            self.create_activation_layer(),
             nn.Linear(self.hidden_size1, self.hidden_size2),
-            self.activation_layer,
+            self.create_activation_layer(),
             nn.Linear(self.hidden_size2, self.output_size),
         )
diff --git a/src/revnets/networks/mininet.py b/src/revnets/networks/mininet.py
index 2ac6018..28c3666 100644
--- a/src/revnets/networks/mininet.py
+++ b/src/revnets/networks/mininet.py
@@ -15,6 +15,6 @@ class NetworkFactory(base.NetworkFactory):
     def create_layers(self) -> Iterable[torch.nn.Module]:
         return (
             nn.Linear(self.input_size, self.hidden_size),
-            self.activation_layer,
+            self.create_activation_layer(),
             nn.Linear(self.hidden_size, self.output_size),
         )
diff --git a/tests/evaluations/test_standardize.py b/tests/evaluations/test_standardize.py
index 3fcefa5..8b6f48a 100644
--- a/tests/evaluations/test_standardize.py
+++ b/tests/evaluations/test_standardize.py
@@ -2,7 +2,7 @@
 import pytest
 from revnets import networks
-from torch import nn
+from revnets.models import Activation
 
 from tests.evaluations.verifier import StandardizationType, Verifier
 
@@ -11,58 +11,63 @@
     StandardizationType.standardize,
     StandardizationType.align,
 )
-network_modules = (networks.mininet,)
-activation_layers = (
-    nn.ReLU(),
-    nn.LeakyReLU(),
-    nn.Tanh(),
+network_modules = (
+    networks.mininet,
+    networks.mediumnet,
+    networks.images.mininet,
+    networks.images.mediumnet,
+)
+activations = (
+    Activation.leaky_relu,
+    Activation.relu,
+    Activation.tanh,
 )
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_standardized_form(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_standardized_form()
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_standardize_preserves_functionality(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_functional_preservation()
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_standardized_form_and_functionality_preservation(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_functional_preservation()
     tester.test_standardized_form()
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_second_standardize_no_effect(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_second_standardize_no_effect()
diff --git a/tests/evaluations/verifier.py b/tests/evaluations/verifier.py
index 015743d..7edccbb 100644
--- a/tests/evaluations/verifier.py
+++ b/tests/evaluations/verifier.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterator
 from dataclasses import dataclass
 from enum import Enum
 from functools import cached_property
@@ -5,11 +6,11 @@
 from typing import cast
 
 import torch
+from revnets.context import context
 from revnets.evaluations.weights import standardize
 from revnets.evaluations.weights.standardize import Standardizer
-from revnets.models import InternalNeurons
-from torch import nn
-from torch.nn import Sequential
+from revnets.models import Activation, InternalNeurons
+from torch.nn import Module, Sequential
 
 
 class StandardizationType(Enum):
@@ -21,7 +22,7 @@ class StandardizationType(Enum):
 @dataclass
 class Verifier:
     network_module: ModuleType
-    activation_layer: nn.Module
+    activation: Activation
     standardization_type: StandardizationType
 
     def __post_init__(self) -> None:
@@ -43,9 +44,9 @@ def apply_transformation(self) -> None:
             case StandardizationType.scale:
                 Standardizer(self.network).standardize_scale()
             case StandardizationType.standardize:
-                return Standardizer(self.network).run()
+                Standardizer(self.network).run()
             case StandardizationType.align:
-                return standardize.align(self.network, self.target)
+                standardize.align(self.network, self.target)
 
     def test_functional_preservation(self) -> None:
         inputs = self.create_network_inputs()
@@ -54,9 +55,7 @@ def test_functional_preservation(self) -> None:
         self.apply_transformation()
         with torch.no_grad():
             outputs_after_transformation = self.network(inputs)
-        outputs_are_closes = torch.isclose(
-            outputs, outputs_after_transformation, rtol=1e-3
-        )
+        outputs_are_closes = torch.isclose(outputs, outputs_after_transformation)
         assert torch.all(outputs_are_closes)
 
     def verify_standardized_form(self) -> None:
@@ -78,19 +77,24 @@ def verify_aligned_form(self) -> None:
 
     def create_network(self) -> Sequential:
         Factory = self.network_module.NetworkFactory
-        factory = Factory(activation_layer=self.activation_layer)
+        factory = Factory(activation=self.activation)
         network = factory.create_network()
         return cast(Sequential, network)
 
     def create_network_inputs(self) -> torch.Tensor:
         size = 1, self.extract_input_size()
-        return torch.rand(size) * 20 - 10
+        return torch.rand(size, dtype=context.dtype) * 20 - 10
 
     def extract_input_size(self) -> int:
-        layers = self.network.children()
-        size = next(layers).weight.shape[1]
+        input_layer = next(self.extract_layers())
+        size = input_layer.weight.shape[1]
         return cast(int, size)
 
+    def extract_layers(self) -> Iterator[Module]:
+        for layer in self.network.children():
+            if hasattr(layer, "weight"):
+                yield layer
+
     def test_second_standardize_no_effect(self) -> None:
         self.apply_transformation()
         state_dict = self.network.state_dict()
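
Review note (not part of the patch): a minimal, self-contained sketch of the stride-2 grouping that the revised generate_triplets in standardize/network.py produces. The toy layer list below is an assumption for illustration; the revnets helpers themselves are not reproduced. With the old zip(items, items[1:], items[2:]), every sliding window was yielded, including (activation, linear, activation) windows; with the stride-2 slices, only (linear, activation, next linear) triplets are yielded, and consecutive triplets share their boundary Linear layer, which appears to match how InternalNeurons pairs an incoming and outgoing layer.

from torch import nn


def generate_triplets(items):
    # Stride-2 slices: yields (linear, activation, next linear) triplets;
    # consecutive triplets overlap on the shared linear layer.
    yield from zip(items[::2], items[1::2], items[2::2])


layers = [
    nn.Linear(4, 8), nn.Tanh(),  # first block
    nn.Linear(8, 6), nn.Tanh(),  # second block
    nn.Linear(6, 2),             # output layer
]
for incoming, activation, outgoing in generate_triplets(layers):
    print(type(incoming).__name__, type(activation).__name__, type(outgoing).__name__)
# Prints:
# Linear Tanh Linear   (Linear(4, 8) -> Tanh -> Linear(8, 6))
# Linear Tanh Linear   (Linear(8, 6) -> Tanh -> Linear(6, 2))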