From 8fe5dd6094e2c3c1f18227f79beaf721e1c8619f Mon Sep 17 00:00:00 2001 From: Matthew Muckley Date: Thu, 14 Mar 2024 21:21:03 +0000 Subject: [PATCH 1/5] Make EntropyBottleneck compilable --- compressai/entropy_models/entropy_models.py | 36 ++++++++++++--------- compressai/models/base.py | 3 +- compressai/models/utils.py | 22 +++++++++++++ tests/test_entropy_models.py | 23 +++++++++++++ 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index 8068a6d2..ec5b69c8 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -29,7 +29,7 @@ import warnings -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Mapping, Optional, Tuple, Union import numpy as np import scipy.stats @@ -360,20 +360,24 @@ def __init__( scale = self.init_scale ** (1 / (len(self.filters) + 1)) channels = self.channels + self.matrices = nn.ParameterList() + self.biases = nn.ParameterList() + self.factors = nn.ParameterList() + for i in range(len(self.filters) + 1): init = np.log(np.expm1(1 / scale / filters[i + 1])) matrix = torch.Tensor(channels, filters[i + 1], filters[i]) matrix.data.fill_(init) - self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix)) + self.matrices.append(nn.Parameter(matrix)) bias = torch.Tensor(channels, filters[i + 1], 1) nn.init.uniform_(bias, -0.5, 0.5) - self.register_parameter(f"_bias{i:d}", nn.Parameter(bias)) + self.biases.append(nn.Parameter(bias)) if i < len(self.filters): factor = torch.Tensor(channels, filters[i + 1], 1) nn.init.zeros_(factor) - self.register_parameter(f"_factor{i:d}", nn.Parameter(factor)) + self.factors.append(nn.Parameter(factor)) self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3)) init = torch.Tensor([-self.init_scale, 0, self.init_scale]) @@ -433,24 +437,23 @@ def _logits_cumulative(self, inputs: Tensor, stop_gradient: bool) -> Tensor: # TorchScript not yet working (nn.Mmodule indexing not supported) logits = inputs for i in range(len(self.filters) + 1): - matrix = getattr(self, f"_matrix{i:d}") + matrix = self.matrices[i] if stop_gradient: matrix = matrix.detach() logits = torch.matmul(F.softplus(matrix), logits) - bias = getattr(self, f"_bias{i:d}") + bias = self.biases[i] if stop_gradient: bias = bias.detach() - logits += bias + logits = logits + bias if i < len(self.filters): - factor = getattr(self, f"_factor{i:d}") + factor = self.factors[i] if stop_gradient: factor = factor.detach() - logits += torch.tanh(factor) * torch.tanh(logits) + logits = logits + torch.tanh(factor) * torch.tanh(logits) return logits - @torch.jit.unused def _likelihood( self, inputs: Tensor, stop_gradient: bool = False ) -> Tuple[Tensor, Tensor, Tensor]: @@ -468,16 +471,19 @@ def forward( if not torch.jit.is_scripting(): # x from B x C x ... to C x B x ... - perm = np.arange(len(x.shape)) - perm[0], perm[1] = perm[1], perm[0] - # Compute inverse permutation - inv_perm = np.arange(len(x.shape))[np.argsort(perm)] + perm = torch.cat( + ( + torch.tensor([1, 0], dtype=torch.long, device=x.device), + torch.arange(2, x.ndim, dtype=torch.long, device=x.device), + ) + ) + inv_perm = perm else: raise NotImplementedError() # TorchScript in 2D for static inference # Convert to (channels, ... , batch) format # perm = (1, 2, 3, 0) - # inv_perm = (3, 0, 1, 2) + # inv_perm = (3, 0, 1, 2): x = x.permute(*perm).contiguous() shape = x.size() diff --git a/compressai/models/base.py b/compressai/models/base.py index a7978ea8..b0791943 100644 --- a/compressai/models/base.py +++ b/compressai/models/base.py @@ -39,7 +39,7 @@ from compressai.entropy_models import EntropyBottleneck, GaussianConditional from compressai.latent_codecs import LatentCodec -from compressai.models.utils import update_registered_buffers +from compressai.models.utils import update_registered_buffers, remap_old_keys __all__ = [ "CompressionModel", @@ -103,6 +103,7 @@ def load_state_dict(self, state_dict, strict=True): ["_quantized_cdf", "_offset", "_cdf_length"], state_dict, ) + state_dict = remap_old_keys(name, state_dict) if isinstance(module, GaussianConditional): update_registered_buffers( diff --git a/compressai/models/utils.py b/compressai/models/utils.py index 57966dc5..9426f0d5 100644 --- a/compressai/models/utils.py +++ b/compressai/models/utils.py @@ -27,10 +27,14 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F +KEY_MAP = {"_bias": "biases", "_matrix": "matrices", "_factor": "factors"} + def find_named_module(module, query): """Helper function to find a named module. Returns a `nn.Module` or `None` @@ -125,6 +129,24 @@ def update_registered_buffers( ) +def remap_old_keys(module_name, state_dict): + def remap_subkey(s: str) -> str: + for k, v in KEY_MAP.items(): + if s.startswith(k): + return ".".join((v, s.split(k)[1])) + + return s + + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith(module_name): + k = ".".join((module_name, remap_subkey(k.split(f"{module_name}.")[1]))) + + new_state_dict[k] = v + + return new_state_dict + + def conv(in_channels, out_channels, kernel_size=5, stride=2): return nn.Conv2d( in_channels, diff --git a/tests/test_entropy_models.py b/tests/test_entropy_models.py index 1eeb7662..a2ec309d 100644 --- a/tests/test_entropy_models.py +++ b/tests/test_entropy_models.py @@ -242,6 +242,29 @@ def test_loss(self): # assert torch.allclose(y0[0], y1[0]) # assert torch.all(y1[1] == 0) # not yet supported + def test_compiling(self): + entropy_bottleneck = EntropyBottleneck(128) + x0 = torch.rand(1, 128, 32, 32) + x1 = x0.clone() + x0.requires_grad_(True) + x1.requires_grad_(True) + + torch.manual_seed(32) + y0 = entropy_bottleneck(x0) + + m = torch.compile(entropy_bottleneck) + + torch.manual_seed(32) + y1 = m(x1) + + assert torch.allclose(y0[0], y1[0]) + assert torch.allclose(y0[1], y1[1]) + + y0[0].sum().backward() + y1[0].sum().backward() + + assert torch.allclose(x0.grad, x1.grad) + def test_update(self): # get a pretrained model net = bmshj2018_factorized(quality=1, pretrained=True).eval() From 3836d7911c3d94df0fe9dbba5cd37f04851d3135 Mon Sep 17 00:00:00 2001 From: Matthew Muckley Date: Thu, 14 Mar 2024 21:23:08 +0000 Subject: [PATCH 2/5] Remove unnecessary import --- compressai/entropy_models/entropy_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index ec5b69c8..dd312fd5 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -29,7 +29,7 @@ import warnings -from typing import Any, Callable, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import scipy.stats From 6d08da7138650ee946994945e2cd8d4f9b6cedbb Mon Sep 17 00:00:00 2001 From: Matthew Muckley Date: Sun, 28 Apr 2024 01:45:12 +0000 Subject: [PATCH 3/5] Change PyTorch version for compile test --- tests/test_entropy_models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_entropy_models.py b/tests/test_entropy_models.py index a2ec309d..789a26e7 100644 --- a/tests/test_entropy_models.py +++ b/tests/test_entropy_models.py @@ -32,6 +32,8 @@ import pytest import torch +from packaging import version + from compressai.entropy_models import ( EntropyBottleneck, EntropyModel, @@ -242,6 +244,10 @@ def test_loss(self): # assert torch.allclose(y0[0], y1[0]) # assert torch.all(y1[1] == 0) # not yet supported + @pytest.mark.skipif( + version.parse(torch.__version__) < version.parse("2.0.0"), + reason="torch.compile only available for torch>=2.0", + ) def test_compiling(self): entropy_bottleneck = EntropyBottleneck(128) x0 = torch.rand(1, 128, 32, 32) From 4c1b28d7ad64eb765ec78c91e5653e6b76813a91 Mon Sep 17 00:00:00 2001 From: Matthew Muckley Date: Mon, 29 Apr 2024 21:14:04 +0000 Subject: [PATCH 4/5] Remove colon --- compressai/entropy_models/entropy_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index dd312fd5..47744440 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -483,7 +483,7 @@ def forward( # TorchScript in 2D for static inference # Convert to (channels, ... , batch) format # perm = (1, 2, 3, 0) - # inv_perm = (3, 0, 1, 2): + # inv_perm = (3, 0, 1, 2) x = x.permute(*perm).contiguous() shape = x.size() From 3ceb497e59513a0dde8282c0571b3ede3f22b27b Mon Sep 17 00:00:00 2001 From: Fabien Racape Date: Tue, 30 Apr 2024 19:43:09 +0200 Subject: [PATCH 5/5] formatting --- compressai/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compressai/models/base.py b/compressai/models/base.py index b0791943..64ff30ea 100644 --- a/compressai/models/base.py +++ b/compressai/models/base.py @@ -39,7 +39,7 @@ from compressai.entropy_models import EntropyBottleneck, GaussianConditional from compressai.latent_codecs import LatentCodec -from compressai.models.utils import update_registered_buffers, remap_old_keys +from compressai.models.utils import remap_old_keys, update_registered_buffers __all__ = [ "CompressionModel",