diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index 033607a6..8e8a9b7c 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -28,6 +28,7 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from .base import LatentCodec +from .channel_groups import ChannelGroupsLatentCodec from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec @@ -38,6 +39,7 @@ __all__ = [ "LatentCodec", + "ChannelGroupsLatentCodec", "CheckerboardLatentCodec", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py new file mode 100644 index 00000000..0cd74baa --- /dev/null +++ b/compressai/latent_codecs/channel_groups.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from itertools import accumulate +from typing import Any, Dict, List, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.registry import register_module + +from .base import LatentCodec + +__all__ = [ + "ChannelGroupsLatentCodec", +] + + +@register_module("ChannelGroupsLatentCodec") +class ChannelGroupsLatentCodec(LatentCodec): + """Reconstructs groups of channels using previously decoded groups. + + Context model from [Minnen2020] and [He2022]. + Also known as a "channel-conditional" (CC) entropy model. + + See :py:class:`~compressai.models.sensetime.Cheng2020AnchorElic` + for example usage. + + [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for + Learned Image Compression" `_, by + David Minnen, and Saurabh Singh, ICIP 2020. + + [He2022]: `"ELIC: Efficient Learned Image Compression with + Unevenly Grouped Space-Channel Contextual Adaptive Coding" + `_, by Dailan He, Ziming Yang, + Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. + """ + + latent_codec: Mapping[str, LatentCodec] + + channel_context: Mapping[str, nn.Module] + + def __init__( + self, + latent_codec: Optional[Mapping[str, LatentCodec]] = None, + channel_context: Optional[Mapping[str, nn.Module]] = None, + *, + groups: List[int], + **kwargs, + ): + super().__init__() + self._kwargs = kwargs + self.groups = list(groups) + self.groups_acc = list(accumulate(self.groups, initial=0)) + self.channel_context = nn.ModuleDict(channel_context) + self.latent_codec = nn.ModuleDict(latent_codec) + + def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + y_ = torch.split(y, self.groups, dim=1) + y_out_ = [{}] * len(self.groups) + y_hat_ = [Tensor()] * len(self.groups) + y_likelihoods_ = [Tensor()] * len(self.groups) + + for k in range(len(self.groups)): + y_hat_prev = torch.cat(y_hat_[:k], dim=1) if k > 0 else Tensor() + params = self._get_ctx_params(k, side_params, y_hat_prev) + y_out_[k] = self.latent_codec[f"y{k}"](y_[k], params) + y_hat_[k] = y_out_[k]["y_hat"] + y_likelihoods_[k] = y_out_[k]["likelihoods"]["y"] + + y_hat = torch.cat(y_hat_, dim=1) + y_likelihoods = torch.cat(y_likelihoods_, dim=1) + + return { + "likelihoods": { + "y": y_likelihoods, + }, + "y_hat": y_hat, + } + + def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + y_ = torch.split(y, self.groups, dim=1) + y_out_ = [{}] * len(self.groups) + y_hat = torch.zeros_like(y) + + for k in range(len(self.groups)): + y_hat_prev = y_hat[:, : self.groups_acc[k]] + params = self._get_ctx_params(k, side_params, y_hat_prev) + y_out_[k] = self.latent_codec[f"y{k}"].compress(y_[k], params) + y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"] + + y_strings_groups = [y_out["strings"] for y_out in y_out_] + assert all(len(y_strings_groups[0]) == len(ss) for ss in y_strings_groups) + + return { + "strings": [s for ss in y_strings_groups for s in ss], + "shape": [y_out["shape"] for y_out in y_out_], + "y_hat": y_hat, + } + + def decompress( + self, + strings: List[List[bytes]], + shape: List[Tuple[int, ...]], + side_params: Tensor, + ) -> Dict[str, Any]: + n = len(strings[0]) + assert all(len(ss) == n for ss in strings) + strings_per_group = len(strings) // len(self.groups) + + y_out_ = [{}] * len(self.groups) + y_shape = (sum(s[0] for s in shape), *shape[0][1:]) + y_hat = torch.zeros((n, *y_shape), device=side_params.device) + + for k in range(len(self.groups)): + y_hat_prev = y_hat[:, : self.groups_acc[k]] + params = self._get_ctx_params(k, side_params, y_hat_prev) + y_strings_k = strings[strings_per_group * k : strings_per_group * (k + 1)] + y_out_[k] = self.latent_codec[f"y{k}"].decompress( + y_strings_k, shape[k], params + ) + y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"] + + return { + "y_hat": y_hat, + } + + def _get_ctx_params( + self, k: int, side_params: Tensor, y_hat_prev: Tensor + ) -> Tensor: + if k == 0: + return side_params + params_ch = self.channel_context[f"y{k}"](y_hat_prev) + return torch.cat([side_params, params_ch], dim=1) diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index d85846f8..a9fae68a 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -27,6 +27,8 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import math + from typing import Any import torch @@ -47,6 +49,7 @@ "conv3x3", "subpel_conv3x3", "QReLU", + "sequential_channel_ramp", ] @@ -321,3 +324,37 @@ def backward(ctx, grad_output): grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value] return grad_input, None, None + + +def sequential_channel_ramp( + in_ch: int, + out_ch: int, + *, + num_layers: int = 3, + interp: str = "linear", + make_layer=None, + make_act=None, + skip_last_act: bool = True, + **layer_kwargs, +) -> nn.Module: + """Interleave layers of gradually ramping channels with nonlinearities.""" + channels = ramp(in_ch, out_ch, num_layers + 1, method=interp).floor().int().tolist() + layers = [ + module + for ch_in, ch_out in zip(channels[:-1], channels[1:]) + for module in [ + make_layer(ch_in, ch_out, **layer_kwargs), + make_act(), + ] + ] + if skip_last_act: + layers = layers[:-1] + return nn.Sequential(*layers) + + +def ramp(a, b, steps=None, method="linear", **kwargs): + if method == "linear": + return torch.linspace(a, b, steps, **kwargs) + if method == "log": + return torch.logspace(math.log10(a), math.log10(b), steps, **kwargs) + raise ValueError(f"Unknown ramp method: {method}") diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index bd260447..a8909d9d 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -27,29 +27,34 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from itertools import accumulate + import torch.nn as nn from compressai.entropy_models import EntropyBottleneck from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, CheckerboardLatentCodec, GaussianConditionalLatentCodec, HyperLatentCodec, HyperpriorLatentCodec, ) from compressai.layers import ( + CheckerboardMaskedConv2d, ResidualBlock, ResidualBlockUpsample, ResidualBlockWithStride, conv3x3, + sequential_channel_ramp, subpel_conv3x3, ) -from compressai.layers.layers import CheckerboardMaskedConv2d from compressai.registry import register_model from .base import SimpleVAECompressionModel __all__ = [ "Cheng2020AnchorCheckerboard", + "Cheng2020AnchorElic", ] @@ -153,3 +158,140 @@ def from_state_dict(cls, state_dict): net = cls(N) net.load_state_dict(state_dict) return net + + +@register_model("cheng2020-anchor-elic") +class Cheng2020AnchorElic(SimpleVAECompressionModel): + """Cheng2020 anchor model with checkerboard context model. + + Base transform model from [Cheng2020]. Context model from [He2022]. + + [Cheng2020]: `"Learned Image Compression with Discretized Gaussian + Mixture Likelihoods and Attention Modules" + `_, by Zhengxue Cheng, Heming Sun, + Masaru Takeuchi, and Jiro Katto, CVPR 2020. + + [He2022]: `"ELIC: Efficient Learned Image Compression with + Unevenly Grouped Space-Channel Contextual Adaptive Coding" + `_, by Dailan He, Ziming Yang, + Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. + + Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel + convolutions for up-sampling. + + Args: + N (int): Number of channels + groups (list[int]): Number of channels in each channel group + """ + + def __init__(self, N=192, groups=None, **kwargs): + super().__init__(**kwargs) + + if groups is None: + groups = [16, 16, 32, 64, 64] + + assert sum(groups) == N + self.groups = list(groups) + self.groups_acc = list(accumulate(self.groups, initial=0)) + + self.g_a = nn.Sequential( + ResidualBlockWithStride(3, N, stride=2), + ResidualBlock(N, N), + ResidualBlockWithStride(N, N, stride=2), + ResidualBlock(N, N), + ResidualBlockWithStride(N, N, stride=2), + ResidualBlock(N, N), + conv3x3(N, N, stride=2), + ) + + self.g_s = nn.Sequential( + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + subpel_conv3x3(N, 3, 2), + ) + + h_a = nn.Sequential( + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N, stride=2), + nn.LeakyReLU(inplace=True), + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N, stride=2), + ) + + h_s = nn.Sequential( + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + subpel_conv3x3(N, N, 2), + nn.LeakyReLU(inplace=True), + conv3x3(N, N * 3 // 2), + nn.LeakyReLU(inplace=True), + subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2), + nn.LeakyReLU(inplace=True), + conv3x3(N * 3 // 2, N * 2), + ) + + self.latent_codec = HyperpriorLatentCodec( + latent_codec={ + "y": ChannelGroupsLatentCodec( + groups=self.groups, + channel_context={ + f"y{k}": sequential_channel_ramp( + self.groups_acc[k], + self.groups[k] * 2, + num_layers=3, + make_layer=nn.Conv2d, + make_act=lambda: nn.ReLU(inplace=True), + kernel_size=5, + padding=2, + stride=1, + ) + for k in range(1, len(self.groups)) + }, + latent_codec={ + f"y{k}": CheckerboardLatentCodec( + latent_codec={ + "y": GaussianConditionalLatentCodec(quantizer="ste"), + }, + entropy_parameters=sequential_channel_ramp( + N * 2 + self.groups[k] * (2 if k == 0 else 4), + self.groups[k] * 2, + num_layers=3, + make_layer=nn.Conv2d, + make_act=lambda: nn.ReLU(inplace=True), + kernel_size=1, + padding=0, + stride=1, + ), + context_prediction=CheckerboardMaskedConv2d( + self.groups[k], + self.groups[k] * 2, + kernel_size=5, + padding=2, + stride=1, + ), + ) + for k in range(len(self.groups)) + }, + ), + "hyper": HyperLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + ), + }, + ) + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.conv1.weight"].size(0) + net = cls(N) + net.load_state_dict(state_dict) + return net diff --git a/docs/source/latent_codecs.rst b/docs/source/latent_codecs.rst index 6289df8e..264a84f0 100644 --- a/docs/source/latent_codecs.rst +++ b/docs/source/latent_codecs.rst @@ -31,6 +31,8 @@ CompressAI provides the following predefined :py:class:`~LatentCodec` subclasses - Like :py:class:`~HyperLatentCodec`, but with trainable gain vectors for ``z``. * - :py:class:`~GainHyperpriorLatentCodec` - Like :py:class:`~HyperpriorLatentCodec`, but with trainable gain vectors for ``y``. + * - :py:class:`~ChannelGroupsLatentCodec` + - Encodes `y` in multiple chunked groups, each group conditioned on previously encoded groups. * - :py:class:`~CheckerboardLatentCodec` - Encodes `y` in two-passes in checkerboard order. @@ -332,6 +334,11 @@ GainHyperpriorLatentCodec .. autoclass:: GainHyperpriorLatentCodec +ChannelGroupsLatentCodec +------------------------ +.. autoclass:: ChannelGroupsLatentCodec + + CheckerboardLatentCodec ----------------------- .. autoclass:: CheckerboardLatentCodec