diff --git a/src/revnets/standardization/network.py b/src/revnets/standardization/network.py index 6dffdfb..9507b38 100644 --- a/src/revnets/standardization/network.py +++ b/src/revnets/standardization/network.py @@ -47,10 +47,8 @@ def _apply_optimize_mae(self) -> None: def calculate_average_scale_per_layer(self) -> float: connection = self.internal_connections[-1] standardizer = scale.Standardizer(connection) - output_scales = standardizer.calculate_outgoing_scales( - connection.output_weights - ) - output_scale = sum(output_scales) / len(output_scales) + scales = standardizer.calculate_outgoing_scales(connection.output_weights) + output_scale = sum(scales) / len(scales) num_internal_connections = len(self.internal_connections) average_scale = output_scale ** (1 / num_internal_connections) return cast(float, average_scale)