
Store nn.Parameter in entropy_models.py in nn.ParameterList #284

Merged (5 commits, May 3, 2024)
32 changes: 19 additions & 13 deletions compressai/entropy_models/entropy_models.py
@@ -360,20 +360,24 @@ def __init__(
         scale = self.init_scale ** (1 / (len(self.filters) + 1))
         channels = self.channels
 
+        self.matrices = nn.ParameterList()
+        self.biases = nn.ParameterList()
+        self.factors = nn.ParameterList()
+
         for i in range(len(self.filters) + 1):
             init = np.log(np.expm1(1 / scale / filters[i + 1]))
             matrix = torch.Tensor(channels, filters[i + 1], filters[i])
             matrix.data.fill_(init)
-            self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix))
+            self.matrices.append(nn.Parameter(matrix))
 
             bias = torch.Tensor(channels, filters[i + 1], 1)
             nn.init.uniform_(bias, -0.5, 0.5)
-            self.register_parameter(f"_bias{i:d}", nn.Parameter(bias))
+            self.biases.append(nn.Parameter(bias))
 
             if i < len(self.filters):
                 factor = torch.Tensor(channels, filters[i + 1], 1)
                 nn.init.zeros_(factor)
-                self.register_parameter(f"_factor{i:d}", nn.Parameter(factor))
+                self.factors.append(nn.Parameter(factor))
 
         self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
         init = torch.Tensor([-self.init_scale, 0, self.init_scale])
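A note on the refactor: nn.ParameterList registers its entries under indexed names, so state_dict keys that used to look like "_matrix0" become "matrices.0" after this change, which is why the remap_old_keys helper is added further down for loading older checkpoints. A minimal sketch with toy modules (illustrative, not taken from the PR), relying only on standard torch.nn behaviour:

    import torch
    import torch.nn as nn

    class OldStyle(nn.Module):
        def __init__(self):
            super().__init__()
            for i in range(2):
                self.register_parameter(f"_matrix{i:d}", nn.Parameter(torch.zeros(1)))

    class NewStyle(nn.Module):
        def __init__(self):
            super().__init__()
            self.matrices = nn.ParameterList(
                [nn.Parameter(torch.zeros(1)) for _ in range(2)]
            )

    print(list(OldStyle().state_dict()))  # ['_matrix0', '_matrix1']
    print(list(NewStyle().state_dict()))  # ['matrices.0', 'matrices.1']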
@@ -433,24 +437,23 @@ def _logits_cumulative(self, inputs: Tensor, stop_gradient: bool) -> Tensor:
         # TorchScript not yet working (nn.Mmodule indexing not supported)
         logits = inputs
         for i in range(len(self.filters) + 1):
-            matrix = getattr(self, f"_matrix{i:d}")
+            matrix = self.matrices[i]
             if stop_gradient:
                 matrix = matrix.detach()
             logits = torch.matmul(F.softplus(matrix), logits)
 
-            bias = getattr(self, f"_bias{i:d}")
+            bias = self.biases[i]
             if stop_gradient:
                 bias = bias.detach()
-            logits += bias
+            logits = logits + bias
 
             if i < len(self.filters):
-                factor = getattr(self, f"_factor{i:d}")
+                factor = self.factors[i]
                 if stop_gradient:
                     factor = factor.detach()
-                logits += torch.tanh(factor) * torch.tanh(logits)
+                logits = logits + torch.tanh(factor) * torch.tanh(logits)
         return logits
 
-    @torch.jit.unused
     def _likelihood(
         self, inputs: Tensor, stop_gradient: bool = False
     ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -468,10 +471,13 @@ def forward(
 
         if not torch.jit.is_scripting():
             # x from B x C x ... to C x B x ...
-            perm = np.arange(len(x.shape))
-            perm[0], perm[1] = perm[1], perm[0]
-            # Compute inverse permutation
-            inv_perm = np.arange(len(x.shape))[np.argsort(perm)]
+            perm = torch.cat(
+                (
+                    torch.tensor([1, 0], dtype=torch.long, device=x.device),
+                    torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
+                )
+            )
Contributor (@YodaEmbedding):
Another possibility:

    perm = torch.tensor(
        [1, 0, *range(2, x.ndim)], dtype=torch.long, device=x.device
    )

Collaborator:
LGTM, might boil down to the same thing under the hood :)

Contributor (PR author):
Hello @YodaEmbedding, thanks for the suggestion. If it's okay, I would like to argue for the current implementation: the proposed alternative relies on a Python-level range iterator and unpacks it into a list, which can trigger many Python calls under the hood.

When working with frameworks like torch.jit and torch.compile, I often find that these kinds of constructs are difficult for the compiler, since the shakiest parts of those libraries are the ones that have to understand arbitrary Python. Keeping everything as PyTorch calls makes the compilers behave more stably.

Contributor:
Sounds good.
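A quick standalone check (illustrative, not from the PR) that the two constructions build the same index tensor, and that swapping the first two dimensions is its own inverse, which is what makes the inv_perm = perm line below sufficient:

    import torch

    x = torch.empty(1, 128, 32, 32)  # any B x C x ... tensor

    perm_cat = torch.cat(
        (
            torch.tensor([1, 0], dtype=torch.long, device=x.device),
            torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
        )
    )
    perm_unpack = torch.tensor(
        [1, 0, *range(2, x.ndim)], dtype=torch.long, device=x.device
    )

    assert torch.equal(perm_cat, perm_unpack)  # both are tensor([1, 0, 2, 3])
    # Swapping dims 0 and 1 is an involution, so the permutation inverts itself.
    assert torch.equal(perm_cat[perm_cat], torch.arange(x.ndim))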

+            inv_perm = perm
         else:
             raise NotImplementedError()
             # TorchScript in 2D for static inference
3 changes: 2 additions & 1 deletion compressai/models/base.py
@@ -39,7 +39,7 @@
 
 from compressai.entropy_models import EntropyBottleneck, GaussianConditional
 from compressai.latent_codecs import LatentCodec
-from compressai.models.utils import update_registered_buffers
+from compressai.models.utils import remap_old_keys, update_registered_buffers
 
 __all__ = [
     "CompressionModel",
@@ -103,6 +103,7 @@ def load_state_dict(self, state_dict, strict=True):
                     ["_quantized_cdf", "_offset", "_cdf_length"],
                     state_dict,
                 )
+                state_dict = remap_old_keys(name, state_dict)
 
             if isinstance(module, GaussianConditional):
                 update_registered_buffers(
22 changes: 22 additions & 0 deletions compressai/models/utils.py
@@ -27,10 +27,14 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+from collections import OrderedDict
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+KEY_MAP = {"_bias": "biases", "_matrix": "matrices", "_factor": "factors"}
+
 
 def find_named_module(module, query):
     """Helper function to find a named module. Returns a `nn.Module` or `None`
@@ -125,6 +129,24 @@ def update_registered_buffers(
         )
 
 
+def remap_old_keys(module_name, state_dict):
+    def remap_subkey(s: str) -> str:
+        for k, v in KEY_MAP.items():
+            if s.startswith(k):
+                return ".".join((v, s.split(k)[1]))
+
+        return s
+
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith(module_name):
+            k = ".".join((module_name, remap_subkey(k.split(f"{module_name}.")[1])))
+
+        new_state_dict[k] = v
+
+    return new_state_dict
+
+
 def conv(in_channels, out_channels, kernel_size=5, stride=2):
     return nn.Conv2d(
         in_channels,
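To illustrate what the new helper does (the function is the one added above; the checkpoint keys below are made up for the example): remap_old_keys rewrites legacy EntropyBottleneck parameter names under the given module prefix and leaves all other keys untouched.

    from collections import OrderedDict

    import torch

    from compressai.models.utils import remap_old_keys

    old = OrderedDict(
        [
            ("entropy_bottleneck._matrix0", torch.zeros(1)),
            ("entropy_bottleneck._bias0", torch.zeros(1)),
            ("g_a.0.weight", torch.zeros(1)),  # different prefix, left as-is
        ]
    )
    new = remap_old_keys("entropy_bottleneck", old)
    print(list(new))
    # ['entropy_bottleneck.matrices.0', 'entropy_bottleneck.biases.0', 'g_a.0.weight']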
29 changes: 29 additions & 0 deletions tests/test_entropy_models.py
@@ -32,6 +32,8 @@
 import pytest
 import torch
 
+from packaging import version
+
 from compressai.entropy_models import (
     EntropyBottleneck,
     EntropyModel,
@@ -242,6 +244,33 @@ def test_loss(self):
         # assert torch.allclose(y0[0], y1[0])
         # assert torch.all(y1[1] == 0) # not yet supported
 
+    @pytest.mark.skipif(
+        version.parse(torch.__version__) < version.parse("2.0.0"),
+        reason="torch.compile only available for torch>=2.0",
+    )
+    def test_compiling(self):
+        entropy_bottleneck = EntropyBottleneck(128)
+        x0 = torch.rand(1, 128, 32, 32)
+        x1 = x0.clone()
+        x0.requires_grad_(True)
+        x1.requires_grad_(True)
+
+        torch.manual_seed(32)
+        y0 = entropy_bottleneck(x0)
+
+        m = torch.compile(entropy_bottleneck)
+
+        torch.manual_seed(32)
+        y1 = m(x1)
+
+        assert torch.allclose(y0[0], y1[0])
+        assert torch.allclose(y0[1], y1[1])
+
+        y0[0].sum().backward()
+        y1[0].sum().backward()
+
+        assert torch.allclose(x0.grad, x1.grad)
+
     def test_update(self):
         # get a pretrained model
         net = bmshj2018_factorized(quality=1, pretrained=True).eval()