fix standardization for all networks
quintenroets committed Apr 9, 2024
1 parent 7e0b12f commit 68a7f45
Showing 9 changed files with 81 additions and 60 deletions.
27 changes: 12 additions & 15 deletions src/revnets/context/context.py
@@ -2,9 +2,8 @@
 
 import torch
 from package_utils.context import Context as Context_
-from torch import nn
 
-from ..models import Activation, Config, HyperParameters, Options, Path
+from ..models import Config, HyperParameters, Options, Path
 
 
 class Context(Context_[Options, Config, None]):
@@ -50,23 +49,21 @@ def log_path(self) -> Path:
     def log_path_str(self) -> str:
        return str(self.log_path)
 
-    @property
-    def activation_layer(self) -> nn.Module:
-        activation = context.config.target_network_training.activation
-        activation_layer: nn.Module
-        match activation:
-            case Activation.leaky_relu:
-                activation_layer = nn.LeakyReLU()
-            case Activation.relu:
-                activation_layer = nn.ReLU()
-            case Activation.tanh:
-                activation_layer = nn.Tanh()
-        return activation_layer
-
     @cached_property
     def device(self) -> torch.device:
         name = "cuda" if torch.cuda.is_available() else "cpu"
         return torch.device(name)
 
+    @property
+    def dtype(self) -> torch.dtype:
+        match self.config.precision:
+            case 32:
+                dtype = torch.float32
+            case 64:
+                dtype = torch.float64
+            case _:
+                raise ValueError(f"Unsupported precision {self.config.precision}")
+        return dtype
+
 
 context = Context(Options, Config, None)
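Note: the activation_layer property moves out of the context (see src/revnets/networks/base.py below), and a dtype property is added so that networks can be cast to the configured precision. A minimal usage sketch of the two context properties; the tensor shape is illustrative, not from the repository:

    import torch
    from revnets.context import context

    # build inputs that match the configured precision and the available device
    inputs = torch.rand(8, 40, dtype=context.dtype, device=context.device)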
7 changes: 4 additions & 3 deletions src/revnets/evaluations/weights/standardize/network.py
@@ -3,7 +3,7 @@
 from functools import cached_property
 from typing import TypeVar, cast
 
-from torch.nn import Module
+from torch.nn import Flatten, Module
 
 from revnets.models import InternalNeurons
 
@@ -57,7 +57,7 @@ def generate_internal_neurons(model: Module) -> Iterator[InternalNeurons]:
 
 
 def generate_triplets(items: list[T]) -> Iterator[tuple[T, T, T]]:
-    yield from zip(items, items[1:], items[2:])
+    yield from zip(items[::2], items[1::2], items[2::2])
 
 
 def generate_layers(model: Module) -> Iterator[Module]:
@@ -69,4 +69,5 @@ def generate_layers(model: Module) -> Iterator[Module]:
         for child in children:
             yield from generate_layers(child)
     else:
-        yield model
+        if not isinstance(model, Flatten):
+            yield model
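Note: the change to generate_triplets is the core of the fix. The old version produced every overlapping window of three consecutive layers, while the new version steps by two, so that, assuming the yielded layers alternate between linear and activation layers, each triplet is one (linear, activation, linear) group and consecutive triplets share only the boundary linear layer. A small sketch of the difference on a toy list; the strings stand in for layers and are purely illustrative:

    layers = ["fc1", "act1", "fc2", "act2", "fc3"]

    # old: overlapping windows
    list(zip(layers, layers[1:], layers[2:]))
    # [('fc1', 'act1', 'fc2'), ('act1', 'fc2', 'act2'), ('fc2', 'act2', 'fc3')]

    # new: stride-2 windows, one per hidden activation
    list(zip(layers[::2], layers[1::2], layers[2::2]))
    # [('fc1', 'act1', 'fc2'), ('fc2', 'act2', 'fc3')]

Skipping Flatten layers in generate_layers keeps this alternation intact for the image networks.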
8 changes: 4 additions & 4 deletions src/revnets/evaluations/weights/standardize/order.py
@@ -13,16 +13,16 @@ class Standardizer:
     neurons: InternalNeurons
 
     def run(self) -> None:
-        sort_indices = calculate_sort_order(self.neurons.incoming)
+        sort_indices = calculate_sort_indices(self.neurons.incoming)
         permute_output_weights(self.neurons.incoming, sort_indices)
         permute_input_weights(self.neurons.outgoing, sort_indices)
 
 
-def calculate_sort_order(layer: nn.Module) -> torch.Tensor:
+def calculate_sort_indices(layer: nn.Module) -> torch.Tensor:
     weights = extract_linear_layer_weights(layer)
     p = 1  # use l1-norm because l2-norm is already standardized
-    total_output_weights = weights.norm(dim=1, p=p)
-    return torch.sort(total_output_weights)[1]
+    sort_values = weights.norm(dim=1, p=p)
+    return torch.sort(sort_values)[1]
 
 
 def permute_input_weights(layer: nn.Module, sort_indices: torch.Tensor) -> None:
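Note: calculate_sort_indices ranks the rows of the layer's weight matrix by their l1 norm and returns the permutation that sorts them ascending; the inline comment explains why the l1 norm is used once the l2 norm has already been fixed by scale standardization. A minimal illustration with a plain tensor; the 3x2 weight matrix is made up for the example:

    import torch

    weights = torch.tensor([[3.0, 1.0], [0.5, 0.5], [2.0, 2.0]])
    sort_values = weights.norm(dim=1, p=1)     # tensor([4., 1., 4.])
    sort_indices = torch.sort(sort_values)[1]  # tensor([1, 0, 2]): ascending row order
    sorted_weights = weights[sort_indices]     # rows reordered by their l1 norm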
2 changes: 1 addition & 1 deletion src/revnets/models/config.py
@@ -18,7 +18,7 @@ class HyperParameters:
     epochs: int = 1
     bias_learning_rate: float | None = None
     batch_size: int = 1
-    activation: Enum = Activation.leaky_relu
+    activation: Activation = Activation.leaky_relu
 
 
 @dataclass
20 changes: 17 additions & 3 deletions src/revnets/networks/base.py
@@ -2,21 +2,35 @@
 from dataclasses import dataclass
 
 import torch
-from torch.nn import Sequential
+from torch import nn
+from torch.nn import Module, Sequential
+
+from revnets.models import Activation
 
 from ..context import context
 from ..utils import NamedClass
 
 
 @dataclass(frozen=True)
 class NetworkFactory(NamedClass):
-    activation_layer: torch.nn.Module = context.activation_layer
     output_size: int = 10
+    activation: Activation = context.config.target_network_training.activation
+
+    def create_activation_layer(self) -> Module:
+        activation_layer: Module
+        match self.activation:
+            case Activation.leaky_relu:
+                activation_layer = nn.LeakyReLU()
+            case Activation.relu:
+                activation_layer = nn.ReLU()
+            case Activation.tanh:
+                activation_layer = nn.Tanh()
+        return activation_layer
 
     def create_network(self, seed: int | None = None) -> Sequential:
         if seed is not None:
             torch.manual_seed(seed)
-        layers = self.create_layers()
+        layers = [layer.to(dtype=context.dtype) for layer in self.create_layers()]
         return Sequential(*layers)
 
     def create_layers(self) -> Iterable[torch.nn.Module]:
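Note: the factory now stores the Activation enum value instead of a shared activation module, so every call to create_activation_layer() returns a fresh layer, and created layers are cast to the configured dtype. A short usage sketch based on the classes shown in this commit; the mininet factory is used purely as an example and relies on its own default sizes:

    from revnets.models import Activation
    from revnets.networks import mininet

    factory = mininet.NetworkFactory(activation=Activation.relu)
    network = factory.create_network(seed=0)  # layers are cast to context.dtype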
4 changes: 2 additions & 2 deletions src/revnets/networks/mediumnet.py
@@ -16,8 +16,8 @@ class NetworkFactory(base.NetworkFactory):
     def create_layers(self) -> Iterable[nn.Module]:
         return (
             nn.Linear(self.input_size, self.hidden_size1),
-            self.activation_layer,
+            self.create_activation_layer(),
             nn.Linear(self.hidden_size1, self.hidden_size2),
-            self.activation_layer,
+            self.create_activation_layer(),
             nn.Linear(self.hidden_size2, self.output_size),
         )
2 changes: 1 addition & 1 deletion src/revnets/networks/mininet.py
@@ -15,6 +15,6 @@ class NetworkFactory(base.NetworkFactory):
     def create_layers(self) -> Iterable[torch.nn.Module]:
         return (
             nn.Linear(self.input_size, self.hidden_size),
-            self.activation_layer,
+            self.create_activation_layer(),
             nn.Linear(self.hidden_size, self.output_size),
         )
41 changes: 23 additions & 18 deletions tests/evaluations/test_standardize.py
@@ -2,7 +2,7 @@
 
 import pytest
 from revnets import networks
-from torch import nn
+from revnets.models import Activation
 
 from tests.evaluations.verifier import StandardizationType, Verifier
 
@@ -11,58 +11,63 @@
     StandardizationType.standardize,
     StandardizationType.align,
 )
-network_modules = (networks.mininet,)
-activation_layers = (
-    nn.ReLU(),
-    nn.LeakyReLU(),
-    nn.Tanh(),
+network_modules = (
+    networks.mininet,
+    networks.mediumnet,
+    networks.images.mininet,
+    networks.images.mediumnet,
 )
+activations = (
+    Activation.leaky_relu,
+    Activation.relu,
+    Activation.tanh,
+)
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_standardized_form(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_standardized_form()
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_standardize_preserves_functionality(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
    standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_functional_preservation()
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_standardized_form_and_functionality_preservation(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_functional_preservation()
     tester.test_standardized_form()
 
 
 @pytest.mark.parametrize("network_module", network_modules)
-@pytest.mark.parametrize("activation_layer", activation_layers)
+@pytest.mark.parametrize("activation", activations)
 @pytest.mark.parametrize("standardization_type", standardization_types)
 def test_second_standardize_no_effect(
     network_module: ModuleType,
-    activation_layer: nn.Module,
+    activation: Activation,
     standardization_type: StandardizationType,
 ) -> None:
-    tester = Verifier(network_module, activation_layer, standardization_type)
+    tester = Verifier(network_module, activation, standardization_type)
     tester.test_second_standardize_no_effect()
30 changes: 17 additions & 13 deletions tests/evaluations/verifier.py
@@ -1,15 +1,16 @@
+from collections.abc import Iterator
 from dataclasses import dataclass
 from enum import Enum
 from functools import cached_property
 from types import ModuleType
 from typing import cast
 
 import torch
+from revnets.context import context
 from revnets.evaluations.weights import standardize
 from revnets.evaluations.weights.standardize import Standardizer
-from revnets.models import InternalNeurons
-from torch import nn
-from torch.nn import Sequential
+from revnets.models import Activation, InternalNeurons
+from torch.nn import Module, Sequential
 
 
 class StandardizationType(Enum):
@@ -21,7 +22,7 @@ class StandardizationType(Enum):
 @dataclass
 class Verifier:
     network_module: ModuleType
-    activation_layer: nn.Module
+    activation: Activation
     standardization_type: StandardizationType
 
     def __post_init__(self) -> None:
@@ -43,9 +44,9 @@ def apply_transformation(self) -> None:
             case StandardizationType.scale:
                 Standardizer(self.network).standardize_scale()
             case StandardizationType.standardize:
-                return Standardizer(self.network).run()
+                Standardizer(self.network).run()
             case StandardizationType.align:
-                return standardize.align(self.network, self.target)
+                standardize.align(self.network, self.target)
 
     def test_functional_preservation(self) -> None:
         inputs = self.create_network_inputs()
@@ -54,9 +55,7 @@ def test_functional_preservation(self) -> None:
         self.apply_transformation()
         with torch.no_grad():
             outputs_after_transformation = self.network(inputs)
-        outputs_are_closes = torch.isclose(
-            outputs, outputs_after_transformation, rtol=1e-3
-        )
+        outputs_are_closes = torch.isclose(outputs, outputs_after_transformation)
         assert torch.all(outputs_are_closes)
 
     def verify_standardized_form(self) -> None:
@@ -78,19 +77,24 @@ def verify_aligned_form(self) -> None:
 
     def create_network(self) -> Sequential:
         Factory = self.network_module.NetworkFactory
-        factory = Factory(activation_layer=self.activation_layer)
+        factory = Factory(activation=self.activation)
         network = factory.create_network()
         return cast(Sequential, network)
 
     def create_network_inputs(self) -> torch.Tensor:
         size = 1, self.extract_input_size()
-        return torch.rand(size) * 20 - 10
+        return torch.rand(size, dtype=context.dtype) * 20 - 10
 
     def extract_input_size(self) -> int:
-        layers = self.network.children()
-        size = next(layers).weight.shape[1]
+        input_layer = next(self.extract_layers())
+        size = input_layer.weight.shape[1]
         return cast(int, size)
 
+    def extract_layers(self) -> Iterator[Module]:
+        for layer in self.network.children():
+            if hasattr(layer, "weight"):
+                yield layer
+
     def test_second_standardize_no_effect(self) -> None:
         self.apply_transformation()
         state_dict = self.network.state_dict()
