-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* generalize standardization * test generalized standardization
- Loading branch information
1 parent
a48aaa8
commit 68a28fe
Showing
42 changed files
with
385 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,10 @@ | ||
import torch | ||
from torch.nn.functional import l1_loss | ||
|
||
from . import mse | ||
|
||
|
||
class Evaluator(mse.Evaluator): | ||
@classmethod | ||
def calculate_weights_distance( | ||
cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor | ||
) -> float: | ||
distance = torch.nn.functional.l1_loss( | ||
original_weights, reconstructed_weights, reduction="sum" | ||
) | ||
return distance.item() | ||
def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float: | ||
return l1_loss(values, other, reduction="sum").item() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,16 @@ | ||
import torch | ||
from torch.nn.functional import l1_loss | ||
|
||
from . import mae | ||
|
||
|
||
class Evaluator(mae.Evaluator): | ||
def calculate_distance(self) -> float: | ||
def calculate_total_distance(self) -> float: | ||
return max( | ||
self.calculate_weights_distance(original, reconstruction) | ||
self.calculate_distance(original, reconstruction) | ||
for original, reconstruction in self.iterate_compared_layers() | ||
) | ||
|
||
@classmethod | ||
def calculate_weights_distance( | ||
cls, original_weights: torch.Tensor, reconstructed_weights: torch.Tensor | ||
) -> float: | ||
distances = torch.nn.functional.l1_loss( | ||
original_weights, reconstructed_weights, reduction="none" | ||
) | ||
distance = distances.max() | ||
return distance.item() | ||
def calculate_distance(cls, values: torch.Tensor, other: torch.Tensor) -> float: | ||
return l1_loss(values, other, reduction="none").max().item() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from . import mediumnet, mininet | ||
from . import cnn, mediumnet, mininet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import lenet, mini |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from collections.abc import Iterable | ||
from dataclasses import dataclass | ||
|
||
from torch import nn | ||
|
||
from . import mini | ||
|
||
|
||
@dataclass | ||
class NetworkFactory(mini.NetworkFactory): | ||
hidden_size1: int = 120 | ||
hidden_size2: int = 84 | ||
|
||
def create_layers(self) -> Iterable[nn.Module]: | ||
yield from ( | ||
# 28 x 28 x 1 | ||
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1), | ||
nn.Tanh(), | ||
# 24 x 24 x 6 | ||
nn.AvgPool2d(kernel_size=2, stride=2), | ||
# 12 x 12 x 6 | ||
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1), | ||
nn.Tanh(), | ||
# 8 x 8 x 16 | ||
nn.AvgPool2d(kernel_size=2, stride=2), | ||
# 4 x 4 x 16 | ||
nn.Flatten(), | ||
# 256 | ||
nn.Linear(in_features=256, out_features=self.hidden_size1), | ||
nn.Tanh(), | ||
nn.Linear(in_features=self.hidden_size1, out_features=self.hidden_size2), | ||
nn.Tanh(), | ||
nn.Linear(in_features=self.hidden_size2, out_features=self.output_size), | ||
) |
Oops, something went wrong.