Skip to content

Commit

Permalink
feat(models): add ELIC
Browse files Browse the repository at this point in the history
Context model from [He2022].

[He2022]: `"ELIC: Efficient Learned Image Compression with
Unevenly Grouped Space-Channel Contextual Adaptive Coding"
<https://arxiv.org/abs/2203.10886>`_, by Dailan He, Ziming Yang,
Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022.
  • Loading branch information
YodaEmbedding committed Jul 21, 2023
1 parent b004d1e commit ee756b0
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 1 deletion.
2 changes: 2 additions & 0 deletions compressai/latent_codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from .base import LatentCodec
from .channel_groups import ChannelGroupsLatentCodec
from .checkerboard import CheckerboardLatentCodec
from .entropy_bottleneck import EntropyBottleneckLatentCodec
from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec
Expand All @@ -38,6 +39,7 @@

__all__ = [
"LatentCodec",
"ChannelGroupsLatentCodec",
"CheckerboardLatentCodec",
"EntropyBottleneckLatentCodec",
"GainHyperLatentCodec",
Expand Down
162 changes: 162 additions & 0 deletions compressai/latent_codecs/channel_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from itertools import accumulate
from typing import Any, Dict, List, Mapping, Optional, Tuple

import torch
import torch.nn as nn

from torch import Tensor

from compressai.registry import register_module

from .base import LatentCodec

__all__ = [
"ChannelGroupsLatentCodec",
]


@register_module("ChannelGroupsLatentCodec")
class ChannelGroupsLatentCodec(LatentCodec):
"""Reconstructs groups of channels using previously decoded groups.
Context model from [Minnen2020] and [He2022].
Also known as a "channel-conditional" (CC) entropy model.
See :py:class:`~compressai.models.sensetime.Cheng2020AnchorElic`
for example usage.
[Minnen2020]: `"Channel-wise Autoregressive Entropy Models for
Learned Image Compression" <https://arxiv.org/abs/2007.08739>`_, by
David Minnen, and Saurabh Singh, ICIP 2020.
[He2022]: `"ELIC: Efficient Learned Image Compression with
Unevenly Grouped Space-Channel Contextual Adaptive Coding"
<https://arxiv.org/abs/2203.10886>`_, by Dailan He, Ziming Yang,
Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022.
"""

latent_codec: Mapping[str, LatentCodec]

channel_context: Mapping[str, nn.Module]

def __init__(
self,
latent_codec: Optional[Mapping[str, LatentCodec]] = None,
channel_context: Optional[Mapping[str, nn.Module]] = None,
*,
groups: List[int],
**kwargs,
):
super().__init__()
self._kwargs = kwargs
self.groups = list(groups)
self.groups_acc = list(accumulate(self.groups, initial=0))
self.channel_context = nn.ModuleDict(channel_context)
self.latent_codec = nn.ModuleDict(latent_codec)

def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
y_ = torch.split(y, self.groups, dim=1)
y_out_ = [{}] * len(self.groups)
y_hat_ = [Tensor()] * len(self.groups)
y_likelihoods_ = [Tensor()] * len(self.groups)

for k in range(len(self.groups)):
y_hat_prev = torch.cat(y_hat_[:k], dim=1) if k > 0 else Tensor()
params = self._get_ctx_params(k, side_params, y_hat_prev)
y_out_[k] = self.latent_codec[f"y{k}"](y_[k], params)
y_hat_[k] = y_out_[k]["y_hat"]
y_likelihoods_[k] = y_out_[k]["likelihoods"]["y"]

y_hat = torch.cat(y_hat_, dim=1)
y_likelihoods = torch.cat(y_likelihoods_, dim=1)

return {
"likelihoods": {
"y": y_likelihoods,
},
"y_hat": y_hat,
}

def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
y_ = torch.split(y, self.groups, dim=1)
y_out_ = [{}] * len(self.groups)
y_hat = torch.zeros_like(y)

for k in range(len(self.groups)):
y_hat_prev = y_hat[:, : self.groups_acc[k]]
params = self._get_ctx_params(k, side_params, y_hat_prev)
y_out_[k] = self.latent_codec[f"y{k}"].compress(y_[k], params)
y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"]

y_strings_groups = [y_out["strings"] for y_out in y_out_]
assert all(len(y_strings_groups[0]) == len(ss) for ss in y_strings_groups)

return {
"strings": [s for ss in y_strings_groups for s in ss],
"shape": [y_out["shape"] for y_out in y_out_],
"y_hat": y_hat,
}

def decompress(
self,
strings: List[List[bytes]],
shape: List[Tuple[int, ...]],
side_params: Tensor,
) -> Dict[str, Any]:
n = len(strings[0])
assert all(len(ss) == n for ss in strings)
strings_per_group = len(strings) // len(self.groups)

y_out_ = [{}] * len(self.groups)
y_shape = (sum(s[0] for s in shape), *shape[0][1:])
y_hat = torch.zeros((n, *y_shape), device=side_params.device)

for k in range(len(self.groups)):
y_hat_prev = y_hat[:, : self.groups_acc[k]]
params = self._get_ctx_params(k, side_params, y_hat_prev)
y_strings_k = strings[strings_per_group * k : strings_per_group * (k + 1)]
y_out_[k] = self.latent_codec[f"y{k}"].decompress(
y_strings_k, shape[k], params
)
y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"]

return {
"y_hat": y_hat,
}

def _get_ctx_params(
self, k: int, side_params: Tensor, y_hat_prev: Tensor
) -> Tensor:
if k == 0:
return side_params
params_ch = self.channel_context[f"y{k}"](y_hat_prev)
return torch.cat([side_params, params_ch], dim=1)
37 changes: 37 additions & 0 deletions compressai/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import math

from typing import Any

import torch
Expand All @@ -47,6 +49,7 @@
"conv3x3",
"subpel_conv3x3",
"QReLU",
"sequential_channel_ramp",
]


Expand Down Expand Up @@ -321,3 +324,37 @@ def backward(ctx, grad_output):
grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value]

return grad_input, None, None


def sequential_channel_ramp(
in_ch: int,
out_ch: int,
*,
num_layers: int = 3,
interp: str = "linear",
make_layer=None,
make_act=None,
skip_last_act: bool = True,
**layer_kwargs,
) -> nn.Module:
"""Interleave layers of gradually ramping channels with nonlinearities."""
channels = ramp(in_ch, out_ch, num_layers + 1, method=interp).floor().int().tolist()
layers = [
module
for ch_in, ch_out in zip(channels[:-1], channels[1:])
for module in [
make_layer(ch_in, ch_out, **layer_kwargs),
make_act(),
]
]
if skip_last_act:
layers = layers[:-1]
return nn.Sequential(*layers)


def ramp(a, b, steps=None, method="linear", **kwargs):
if method == "linear":
return torch.linspace(a, b, steps, **kwargs)
if method == "log":
return torch.logspace(math.log10(a), math.log10(b), steps, **kwargs)
raise ValueError(f"Unknown ramp method: {method}")
Loading

0 comments on commit ee756b0

Please sign in to comment.