From 2d0873eee2990ce311d88464d165cce028057219 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Wed, 26 Apr 2023 21:59:58 -0700
Subject: [PATCH 01/37] feat: fast EntropyBottleneck aux_loss minimization via
 bisection search

This method completes in <1 second and reduces aux_loss to <0.01.
This makes the aux_loss optimization during training unnecessary.

Another alternative would be to run the following post-training:

```python
aux_loss = model.aux_loss()
while aux_loss > 0.1:
    aux_loss.backward()
    aux_optimizer.step()
    aux_optimizer.zero_grad()
    aux_loss = model.aux_loss()
```

...but since we do not manage aux_loss learning rates, the bisection
search method might converge better.
---
 compressai/entropy_models/entropy_models.py | 34 ++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py
index 038ba21d..e14827fb 100644
--- a/compressai/entropy_models/entropy_models.py
+++ b/compressai/entropy_models/entropy_models.py
@@ -386,12 +386,15 @@ def _get_medians(self) -> Tensor:
         medians = self.quantiles[:, :, 1:2]
         return medians
 
-    def update(self, force: bool = False) -> bool:
+    def update(self, force: bool = False, update_quantiles: bool = True) -> bool:
         # Check if we need to update the bottleneck parameters, the offsets are
         # only computed and stored when the conditonal model is update()'d.
         if self._offset.numel() > 0 and not force:
             return False
 
+        if update_quantiles:
+            self._update_quantiles()
+
         medians = self.quantiles[:, 0, 1]
 
         minima = medians - self.quantiles[:, 0, 0]
@@ -521,6 +524,35 @@ def _build_indexes(size):
     def _extend_ndims(tensor, n):
         return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1)
 
+    @torch.no_grad()
+    def _update_quantiles(self, search_radius=1e5, rtol=1e-4, atol=1e-3):
+        device = self.quantiles.device
+        shape = (self.channels, 1, 1)
+        low = torch.full(shape, -search_radius, device=device)
+        high = torch.full(shape, search_radius, device=device)
+
+        def f(y, self=self):
+            return self._logits_cumulative(y, stop_gradient=True)
+
+        for i in range(len(self.target)):
+            q_i = self._search_target(f, self.target[i], low, high, rtol, atol)
+            self.quantiles[:, :, i] = q_i[:, :, 0]
+
+    @staticmethod
+    def _search_target(f, target, low, high, rtol=1e-4, atol=1e-3, strict=False):
+        assert (low <= high).all()
+        if strict:
+            assert ((f(low) <= target) & (target <= f(high))).all()
+        else:
+            low = torch.where(target <= f(high), low, high)
+            high = torch.where(f(low) <= target, high, low)
+        while not torch.isclose(low, high, rtol=rtol, atol=atol).all():
+            mid = (low + high) / 2
+            f_mid = f(mid)
+            low = torch.where(f_mid <= target, mid, low)
+            high = torch.where(f_mid >= target, mid, high)
+        return (low + high) / 2
+
     def compress(self, x):
         indexes = self._build_indexes(x.size())
         medians = self._get_medians().detach()

From 4ccc68e85086000d131655d7c3eb61a00bf45eb7 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Mon, 25 Sep 2023 01:27:15 -0700
Subject: [PATCH 02/37] fix: repair tests to work with non-zero medians

---
 tests/test_entropy_models.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/tests/test_entropy_models.py b/tests/test_entropy_models.py
index 00c83287..2dbfc1c4 100644
--- a/tests/test_entropy_models.py
+++ b/tests/test_entropy_models.py
@@ -261,26 +261,30 @@ def test_compression_2D(self):
         eb.update()
         s = eb.compress(x)
         x2 = eb.decompress(s, x.size()[2:])
+        means = eb._get_medians()
 
-        assert torch.allclose(torch.round(x), x2)
+        assert
torch.allclose(torch.round(x - means) + means, x2) def test_compression_ND(self): eb = EntropyBottleneck(128) eb.update() + # Test 0D x = torch.rand(1, 128) s = eb.compress(x) x2 = eb.decompress(s, []) + means = eb._get_medians().reshape(128) - assert torch.allclose(torch.round(x), x2) + assert torch.allclose(torch.round(x - means) + means, x2) # Test from 1 to 5 dimensions for i in range(1, 6): x = torch.rand(1, 128, *([4] * i)) s = eb.compress(x) x2 = eb.decompress(s, x.size()[2:]) + means = eb._get_medians().reshape(128, *([1] * i)) - assert torch.allclose(torch.round(x), x2) + assert torch.allclose(torch.round(x - means) + means, x2) class TestGaussianConditional: From cdac4678f6917c028402aaaf0997a5481bf2b2be Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 5 Mar 2023 23:53:17 -0800 Subject: [PATCH 03/37] feat(models): add Checkerboard Checkerboard context model introduced in [He2021]. [He2021]: `"Checkerboard Context Model for Efficient Learned Image Compression" `_, by Dailan He, Yaoyan Zheng, Baocheng Sun, Yan Wang, and Hongwei Qin, CVPR 2021. --- compressai/latent_codecs/__init__.py | 2 + compressai/latent_codecs/checkerboard.py | 217 +++++++++++++++++++++++ compressai/layers/layers.py | 27 +++ compressai/models/__init__.py | 1 + compressai/models/sensetime.py | 156 ++++++++++++++++ docs/source/latent_codecs.rst | 7 + docs/source/models.rst | 5 + 7 files changed, 415 insertions(+) create mode 100644 compressai/latent_codecs/checkerboard.py create mode 100644 compressai/models/sensetime.py diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index 90e3915f..033607a6 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -28,6 +28,7 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from .base import LatentCodec +from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec from .gaussian_conditional import GaussianConditionalLatentCodec @@ -37,6 +38,7 @@ __all__ = [ "LatentCodec", + "CheckerboardLatentCodec", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", "GainHyperpriorLatentCodec", diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py new file mode 100644 index 00000000..4cbe5d2e --- /dev/null +++ b/compressai/latent_codecs/checkerboard.py @@ -0,0 +1,217 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Any, Dict, List, Mapping, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyModel +from compressai.layers import CheckerboardMaskedConv2d +from compressai.registry import register_module + +from .base import LatentCodec +from .gaussian_conditional import GaussianConditionalLatentCodec + +__all__ = [ + "CheckerboardLatentCodec", +] + + +@register_module("CheckerboardLatentCodec") +class CheckerboardLatentCodec(LatentCodec): + """Reconstructs latent using 2-pass context model with checkerboard anchors. + + Checkerboard context model introduced in [He2021]. + + [He2021]: `"Checkerboard Context Model for Efficient Learned Image + Compression" `_, by Dailan He, + Yaoyan Zheng, Baocheng Sun, Yan Wang, and Hongwei Qin, CVPR 2021. + + .. warning:: This implementation assumes that ``entropy_parameters`` + is a pointwise function, e.g., a composition of 1x1 convs and + pointwise nonlinearities. + + .. note:: This implementation uses uniform noise for training quantization. + + .. code-block:: none + + 0. Input: + + ■ ■ ■ ■ + ■ ■ ■ ■ + ■ ■ ■ ■ + + 1. Decode anchors: + + ◌ ■ ◌ ■ + ■ ◌ ■ ◌ + ◌ ■ ◌ ■ + + 2. Decode non-anchors: + + □ ◌ □ ◌ + ◌ □ ◌ □ + □ ◌ □ ◌ + + 3. 
End result: + + □ □ □ □ + □ □ □ □ + □ □ □ □ + + LEGEND: + □ decoded + ◌ currently decoding + ■ empty + """ + + latent_codec: Mapping[str, LatentCodec] + + entropy_parameters: nn.Module + context_prediction: CheckerboardMaskedConv2d + + def __init__(self, **kwargs): + super().__init__() + self._kwargs = kwargs + self._setdefault("entropy_parameters", nn.Identity) + self._setdefault("context_prediction", nn.Identity) + self._set_group_defaults( + "latent_codec", + defaults={ + "y": GaussianConditionalLatentCodec, + }, + save_direct=True, + ) + + def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + y_hat = self.quantize(y) + ctx_params = self.entropy_parameters( + self.merge(side_params, self.context_prediction(y_hat)) + ) + y_out = self.latent_codec["y"](y, ctx_params) + return { + "likelihoods": { + "y": y_out["likelihoods"]["y"], + }, + "y_hat": y_hat, + } + + def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + n, c, h, w = y.shape + y_hat_ = side_params.new_zeros((2, n, c, h, w // 2)) + side_params_ = self.unembed(side_params) + y_ = self.unembed(y) + y_strings_ = [None] * 2 + + for i in range(2): + y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] + ctx_params_i = self.entropy_parameters(self.merge(side_params_[i], y_ctx_i)) + y_out = self.latent_codec["y"].compress(y_[i], ctx_params_i) + y_hat_[i] = y_out["y_hat"] + [y_strings_[i]] = y_out["strings"] + + y_hat = self.embed(y_hat_) + + return { + "strings": y_strings_, + "shape": y_hat.shape[1:], + "y_hat": y_hat, + } + + def decompress( + self, strings: List[List[bytes]], shape: Tuple[int, ...], side_params: Tensor + ) -> Dict[str, Any]: + y_strings_ = strings + n = len(y_strings_[0]) + assert len(y_strings_) == 2 + assert all(len(x) == n for x in y_strings_) + + c, h, w = shape + y_hat_ = side_params.new_zeros((2, n, c, h, w // 2)) + side_params_ = self.unembed(side_params) + + for i in range(2): + y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] + ctx_params_i = self.entropy_parameters(self.merge(side_params_[i], y_ctx_i)) + y_out = self.latent_codec["y"].decompress( + [y_strings_[i]], shape=(h, w // 2), ctx_params=ctx_params_i + ) + y_hat_[i] = y_out["y_hat"] + + y_hat = self.embed(y_hat_) + + return { + "y_hat": y_hat, + } + + def unembed(self, y: Tensor) -> Tensor: + """Separate single tensor into two even/odd checkerboard chunks. + + .. code-block:: none + + □ ■ □ ■ □ □ ■ ■ + ■ □ ■ □ ---> □ □ ■ ■ + □ ■ □ ■ □ □ ■ ■ + """ + n, c, h, w = y.shape + y_ = y.new_zeros((2, n, c, h, w // 2)) + y_[0, ..., 0::2, :] = y[..., 0::2, 0::2] + y_[0, ..., 1::2, :] = y[..., 1::2, 1::2] + y_[1, ..., 0::2, :] = y[..., 0::2, 1::2] + y_[1, ..., 1::2, :] = y[..., 1::2, 0::2] + return y_ + + def embed(self, y_: Tensor) -> Tensor: + """Combine two even/odd checkerboard chunks into single tensor. + + .. 
code-block:: none + + □ □ ■ ■ □ ■ □ ■ + □ □ ■ ■ ---> ■ □ ■ □ + □ □ ■ ■ □ ■ □ ■ + """ + num_chunks, n, c, h, w_half = y_.shape + assert num_chunks == 2 + y = y_.new_zeros((n, c, h, w_half * 2)) + y[..., 0::2, 0::2] = y_[0, ..., 0::2, :] + y[..., 1::2, 1::2] = y_[0, ..., 1::2, :] + y[..., 0::2, 1::2] = y_[1, ..., 0::2, :] + y[..., 1::2, 0::2] = y_[1, ..., 1::2, :] + return y + + def merge(self, *args): + return torch.cat(args, dim=1) + + def quantize(self, y: Tensor) -> Tensor: + mode = "noise" if self.training else "dequantize" + y_hat = EntropyModel.quantize(None, y, mode) + return y_hat diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index 540771ef..d85846f8 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -40,6 +40,7 @@ __all__ = [ "AttentionBlock", "MaskedConv2d", + "CheckerboardMaskedConv2d", "ResidualBlock", "ResidualBlockUpsample", "ResidualBlockWithStride", @@ -78,6 +79,32 @@ def forward(self, x: Tensor) -> Tensor: return super().forward(x) +class CheckerboardMaskedConv2d(MaskedConv2d): + r"""Checkerboard masked 2D convolution; mask future "unseen" pixels. + + Checkerboard mask variant used in + `"Checkerboard Context Model for Efficient Learned Image Compression" + `_, by Dailan He, Yaoyan Zheng, + Baocheng Sun, Yan Wang, and Hongwei Qin, CVPR 2021. + + Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the + first layer (which also masks the "current pixel"), `mask_type='B'` for the + following layers. + """ + + def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any): + super().__init__(*args, **kwargs) + + if mask_type not in ("A", "B"): + raise ValueError(f'Invalid "mask_type" value "{mask_type}"') + + _, _, h, w = self.mask.size() + self.mask[:] = 1 + self.mask[:, :, 0::2, 0::2] = 0 + self.mask[:, :, 1::2, 1::2] = 0 + self.mask[:, :, h // 2, w // 2] = mask_type == "B" + + def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module: """3x3 convolution with padding.""" return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) diff --git a/compressai/models/__init__.py b/compressai/models/__init__.py index 8bdb5749..94876ccb 100644 --- a/compressai/models/__init__.py +++ b/compressai/models/__init__.py @@ -29,4 +29,5 @@ from .base import * from .google import * +from .sensetime import * from .waseda import * diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py new file mode 100644 index 00000000..d487e487 --- /dev/null +++ b/compressai/models/sensetime.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch.nn as nn + +from compressai.entropy_models import EntropyBottleneck, GaussianConditional +from compressai.latent_codecs import ( + CheckerboardLatentCodec, + HyperLatentCodec, + HyperpriorLatentCodec, +) +from compressai.layers import ( + ResidualBlock, + ResidualBlockUpsample, + ResidualBlockWithStride, + conv3x3, + subpel_conv3x3, +) +from compressai.layers.layers import CheckerboardMaskedConv2d +from compressai.registry import register_model + +from .base import SimpleVAECompressionModel + +__all__ = [ + "Cheng2020AnchorCheckerboard", +] + + +@register_model("cheng2020-anchor-checkerboard") +class Cheng2020AnchorCheckerboard(SimpleVAECompressionModel): + """Cheng2020 anchor model with checkerboard context model. + + Base transform model from [Cheng2020]. Context model from [He2021]. + + [Cheng2020]: `"Learned Image Compression with Discretized Gaussian + Mixture Likelihoods and Attention Modules" + `_, by Zhengxue Cheng, Heming Sun, + Masaru Takeuchi, and Jiro Katto, CVPR 2020. + + [He2021]: `"Checkerboard Context Model for Efficient Learned Image + Compression" `_, by Dailan He, + Yaoyan Zheng, Baocheng Sun, Yan Wang, and Hongwei Qin, CVPR 2021. + + Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel + convolutions for up-sampling. 
+ + Args: + N (int): Number of channels + """ + + def __init__(self, N=192, **kwargs): + super().__init__(**kwargs) + + self.g_a = nn.Sequential( + ResidualBlockWithStride(3, N, stride=2), + ResidualBlock(N, N), + ResidualBlockWithStride(N, N, stride=2), + ResidualBlock(N, N), + ResidualBlockWithStride(N, N, stride=2), + ResidualBlock(N, N), + conv3x3(N, N, stride=2), + ) + + self.g_s = nn.Sequential( + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + subpel_conv3x3(N, 3, 2), + ) + + h_a = nn.Sequential( + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N, stride=2), + nn.LeakyReLU(inplace=True), + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N, stride=2), + ) + + h_s = nn.Sequential( + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + subpel_conv3x3(N, N, 2), + nn.LeakyReLU(inplace=True), + conv3x3(N, N * 3 // 2), + nn.LeakyReLU(inplace=True), + subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2), + nn.LeakyReLU(inplace=True), + conv3x3(N * 3 // 2, N * 2), + ) + + self.latent_codec = HyperpriorLatentCodec( + N, + latent_codec={ + "y": CheckerboardLatentCodec( + gaussian_conditional=GaussianConditional(None), + entropy_parameters=nn.Sequential( + nn.Conv2d(N * 12 // 3, N * 10 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(N * 10 // 3, N * 8 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(N * 8 // 3, N * 6 // 3, 1), + ), + context_prediction=CheckerboardMaskedConv2d( + N, 2 * N, kernel_size=5, padding=2, stride=1 + ), + ), + "hyper": HyperLatentCodec( + N, + h_a=h_a, + h_s=h_s, + entropy_bottleneck=EntropyBottleneck(N), + ), + }, + ) + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.conv1.weight"].size(0) + net = cls(N) + net.load_state_dict(state_dict) + return net diff --git a/docs/source/latent_codecs.rst b/docs/source/latent_codecs.rst index ba1a8efe..b3ca3b6a 100644 --- a/docs/source/latent_codecs.rst +++ b/docs/source/latent_codecs.rst @@ -31,6 +31,8 @@ CompressAI provides the following predefined :py:class:`~LatentCodec` subclasses - Like :py:class:`~HyperLatentCodec`, but with trainable gain vectors for ``z``. * - :py:class:`~GainHyperpriorLatentCodec` - Like :py:class:`~HyperpriorLatentCodec`, but with trainable gain vectors for ``y``. + * - :py:class:`~CheckerboardLatentCodec` + - Encodes ``y`` in two passes in checkerboard order. Diagrams for some of the above predefined latent codecs: @@ -329,3 +331,8 @@ GainHyperpriorLatentCodec ------------------------- .. autoclass:: GainHyperpriorLatentCodec + +CheckerboardLatentCodec +----------------------- +.. autoclass:: CheckerboardLatentCodec + diff --git a/docs/source/models.rst b/docs/source/models.rst index 3703df79..0d69f59d 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -45,6 +45,11 @@ Cheng2020Attention .. autoclass:: Cheng2020Attention +Cheng2020AnchorCheckerboard +--------------------------- +.. autoclass:: Cheng2020AnchorCheckerboard + + .. 
currentmodule:: compressai.models.video

ScaleSpaceFlow

From 7a8279ecba06e0f46a09799f225ab012ceaeb4b6 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Tue, 7 Mar 2023 01:49:42 -0800
Subject: [PATCH 04/37] fix: CheckerboardLatentCodec mask anchors of ctx pred

---
 compressai/latent_codecs/checkerboard.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py
index 4cbe5d2e..ff6ad834 100644
--- a/compressai/latent_codecs/checkerboard.py
+++ b/compressai/latent_codecs/checkerboard.py
@@ -114,9 +114,8 @@ def __init__(self, **kwargs):
 
     def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
         y_hat = self.quantize(y)
-        ctx_params = self.entropy_parameters(
-            self.merge(side_params, self.context_prediction(y_hat))
-        )
+        y_ctx = self._mask_anchor(self.context_prediction(y_hat))
+        ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx))
         y_out = self.latent_codec["y"](y, ctx_params)
         return {
             "likelihoods": {
@@ -208,6 +207,11 @@ def embed(self, y_: Tensor) -> Tensor:
             y[..., 1::2, 0::2] = y_[1, ..., 1::2, :]
         return y
 
+    def _mask_anchor(self, y: Tensor) -> Tensor:
+        y[..., 0::2, 0::2] = 0
+        y[..., 1::2, 1::2] = 0
+        return y
+
     def merge(self, *args):
         return torch.cat(args, dim=1)

From a3bbebe0217c707471580c32057dcb1b3fed1efe Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Wed, 8 Mar 2023 16:08:28 -0800
Subject: [PATCH 05/37] fix: CheckerboardLatentCodec 'two-pass' forward with
 STE gives better results

Gives an immediate reduction of roughly 2% in validation loss on a
currently training model.

The changes that may have contributed to this improvement are:

- Do context prediction of non-anchors on STE-quantized y_hat anchors instead.
- Use STE-quantized y_hat for non-anchors.
- Reuse the exact same quantized y_hat for anchors that was used in
  predicting the non-anchors. (Note: probably unnecessary since the new
  means_hat should already contain the same values for the anchor.)
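Roughly, the new two-pass forward looks like this (a simplified sketch of
the diff below; quantize_ste is the straight-through rounding op imported
from compressai.ops, and the non-anchor masking of ctx_params is omitted):

```python
# Pass 1: zero out the spatial context, recover means_hat from the side
# params alone, then quantize the anchors with straight-through rounding.
y_ctx = torch.zeros_like(self.context_prediction(y))
ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx))
_, means_hat = self.latent_codec["y"].entropy_parameters(ctx_params).chunk(2, 1)
y_hat_anchors = quantize_ste(y - means_hat) + means_hat

# Pass 2: predict the non-anchor parameters from the STE-quantized anchors.
y_ctx = self._mask_anchor(self.context_prediction(y_hat_anchors))
ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx))
y_out = self.latent_codec["y"](y, ctx_params)
```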
--- compressai/latent_codecs/checkerboard.py | 40 ++++++++++++++++++++++-- compressai/models/sensetime.py | 13 +++----- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index ff6ad834..d9d46b71 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -36,6 +36,7 @@ from compressai.entropy_models import EntropyModel from compressai.layers import CheckerboardMaskedConv2d +from compressai.ops import quantize_ste from compressai.registry import register_module from .base import LatentCodec @@ -99,20 +100,23 @@ class CheckerboardLatentCodec(LatentCodec): entropy_parameters: nn.Module context_prediction: CheckerboardMaskedConv2d - def __init__(self, **kwargs): + def __init__(self, forward_method="twopass", **kwargs): super().__init__() self._kwargs = kwargs + self.forward_method = forward_method self._setdefault("entropy_parameters", nn.Identity) self._setdefault("context_prediction", nn.Identity) self._set_group_defaults( "latent_codec", defaults={ - "y": GaussianConditionalLatentCodec, + "y": lambda: GaussianConditionalLatentCodec(quantizer="ste"), }, save_direct=True, ) def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + if self.forward_method == "twopass": + return self._forward_twopass(y, side_params) y_hat = self.quantize(y) y_ctx = self._mask_anchor(self.context_prediction(y_hat)) ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx)) @@ -124,6 +128,33 @@ def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: "y_hat": y_hat, } + def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + """Do context prediction on STE-quantized y_hat instead.""" + y_hat_anchors = self._y_hat_anchors(y, side_params) + y_ctx = self._mask_anchor(self.context_prediction(y_hat_anchors)) + ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx)) + y_out = self.latent_codec["y"](y, ctx_params) + # Reuse quantized y_hat that was used for non-anchor context prediction. + y_hat = y_out["y_hat"] + y_hat[..., 0::2, 0::2] = y_hat_anchors[..., 0::2, 0::2] + y_hat[..., 1::2, 1::2] = y_hat_anchors[..., 1::2, 1::2] + return { + "likelihoods": { + "y": y_out["likelihoods"]["y"], + }, + "y_hat": y_hat, + } + + def _y_hat_anchors(self, y, side_params): + y_ctx = self.context_prediction(y).detach() + y_ctx[:] = 0 + ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx)) + ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) + ctx_params = self._mask_non_anchor(ctx_params) # Probably not needed. 
+ _, means_hat = ctx_params.chunk(2, 1) + y_hat = quantize_ste(y - means_hat) + means_hat + return y_hat + def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: n, c, h, w = y.shape y_hat_ = side_params.new_zeros((2, n, c, h, w // 2)) @@ -212,6 +243,11 @@ def _mask_anchor(self, y: Tensor) -> Tensor: y[..., 1::2, 1::2] = 0 return y + def _mask_non_anchor(self, y: Tensor) -> Tensor: + y[..., 0::2, 1::2] = 0 + y[..., 1::2, 0::2] = 0 + return y + def merge(self, *args): return torch.cat(args, dim=1) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index d487e487..3c645c1f 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -29,9 +29,9 @@ import torch.nn as nn -from compressai.entropy_models import EntropyBottleneck, GaussianConditional from compressai.latent_codecs import ( CheckerboardLatentCodec, + GaussianConditionalLatentCodec, HyperLatentCodec, HyperpriorLatentCodec, ) @@ -126,7 +126,9 @@ def __init__(self, N=192, **kwargs): N, latent_codec={ "y": CheckerboardLatentCodec( - gaussian_conditional=GaussianConditional(None), + latent_codec={ + "y": GaussianConditionalLatentCodec(quantizer="ste"), + }, entropy_parameters=nn.Sequential( nn.Conv2d(N * 12 // 3, N * 10 // 3, 1), nn.LeakyReLU(inplace=True), @@ -138,12 +140,7 @@ def __init__(self, N=192, **kwargs): N, 2 * N, kernel_size=5, padding=2, stride=1 ), ), - "hyper": HyperLatentCodec( - N, - h_a=h_a, - h_s=h_s, - entropy_bottleneck=EntropyBottleneck(N), - ), + "hyper": HyperLatentCodec(N, h_a=h_a, h_s=h_s), }, ) From 165f39668d093cf2acb3ce32188135685f12d9dd Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 19 Jul 2023 19:39:06 -0700 Subject: [PATCH 06/37] fix: correct latent codec usage --- compressai/latent_codecs/checkerboard.py | 16 ++++++++++++---- compressai/models/sensetime.py | 6 ++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index d9d46b71..4bdf918d 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -27,7 +27,7 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from typing import Any, Dict, List, Mapping, Tuple +from typing import Any, Dict, List, Mapping, Optional, Tuple import torch import torch.nn as nn @@ -100,14 +100,22 @@ class CheckerboardLatentCodec(LatentCodec): entropy_parameters: nn.Module context_prediction: CheckerboardMaskedConv2d - def __init__(self, forward_method="twopass", **kwargs): + def __init__( + self, + latent_codec: Optional[Mapping[str, LatentCodec]] = None, + entropy_parameters: Optional[nn.Module] = None, + context_prediction: Optional[nn.Module] = None, + forward_method="twopass", + **kwargs, + ): super().__init__() self._kwargs = kwargs self.forward_method = forward_method - self._setdefault("entropy_parameters", nn.Identity) - self._setdefault("context_prediction", nn.Identity) + self.entropy_parameters = entropy_parameters or nn.Identity() + self.context_prediction = context_prediction or nn.Identity() self._set_group_defaults( "latent_codec", + latent_codec, defaults={ "y": lambda: GaussianConditionalLatentCodec(quantizer="ste"), }, diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index 3c645c1f..bd260447 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -29,6 +29,7 @@ import torch.nn as nn +from compressai.entropy_models import EntropyBottleneck from compressai.latent_codecs import ( CheckerboardLatentCodec, GaussianConditionalLatentCodec, @@ -123,7 +124,6 @@ def __init__(self, N=192, **kwargs): ) self.latent_codec = HyperpriorLatentCodec( - N, latent_codec={ "y": CheckerboardLatentCodec( latent_codec={ @@ -140,7 +140,9 @@ def __init__(self, N=192, **kwargs): N, 2 * N, kernel_size=5, padding=2, stride=1 ), ), - "hyper": HyperLatentCodec(N, h_a=h_a, h_s=h_s), + "hyper": HyperLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + ), }, ) From 459683fa9f8ad8345c4bf99137dde6b25e9e6d65 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Thu, 20 Jul 2023 16:01:56 -0700 Subject: [PATCH 07/37] feat(models): add ELIC Context model from [He2022]. [He2022]: `"ELIC: Efficient Learned Image Compression with Unevenly Grouped Space-Channel Contextual Adaptive Coding" `_, by Dailan He, Ziming Yang, Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. --- compressai/latent_codecs/__init__.py | 2 + compressai/latent_codecs/channel_groups.py | 162 +++++++++++++++++++++ compressai/layers/layers.py | 37 +++++ compressai/models/sensetime.py | 144 +++++++++++++++++- docs/source/latent_codecs.rst | 7 + docs/source/models.rst | 5 + 6 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 compressai/latent_codecs/channel_groups.py diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index 033607a6..8e8a9b7c 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -28,6 +28,7 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from .base import LatentCodec +from .channel_groups import ChannelGroupsLatentCodec from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec @@ -38,6 +39,7 @@ __all__ = [ "LatentCodec", + "ChannelGroupsLatentCodec", "CheckerboardLatentCodec", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py new file mode 100644 index 00000000..0cd74baa --- /dev/null +++ b/compressai/latent_codecs/channel_groups.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from itertools import accumulate +from typing import Any, Dict, List, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.registry import register_module + +from .base import LatentCodec + +__all__ = [ + "ChannelGroupsLatentCodec", +] + + +@register_module("ChannelGroupsLatentCodec") +class ChannelGroupsLatentCodec(LatentCodec): + """Reconstructs groups of channels using previously decoded groups. + + Context model from [Minnen2020] and [He2022]. + Also known as a "channel-conditional" (CC) entropy model. + + See :py:class:`~compressai.models.sensetime.Cheng2020AnchorElic` + for example usage. + + [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for + Learned Image Compression" `_, by + David Minnen, and Saurabh Singh, ICIP 2020. + + [He2022]: `"ELIC: Efficient Learned Image Compression with + Unevenly Grouped Space-Channel Contextual Adaptive Coding" + `_, by Dailan He, Ziming Yang, + Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. 
+ """ + + latent_codec: Mapping[str, LatentCodec] + + channel_context: Mapping[str, nn.Module] + + def __init__( + self, + latent_codec: Optional[Mapping[str, LatentCodec]] = None, + channel_context: Optional[Mapping[str, nn.Module]] = None, + *, + groups: List[int], + **kwargs, + ): + super().__init__() + self._kwargs = kwargs + self.groups = list(groups) + self.groups_acc = list(accumulate(self.groups, initial=0)) + self.channel_context = nn.ModuleDict(channel_context) + self.latent_codec = nn.ModuleDict(latent_codec) + + def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + y_ = torch.split(y, self.groups, dim=1) + y_out_ = [{}] * len(self.groups) + y_hat_ = [Tensor()] * len(self.groups) + y_likelihoods_ = [Tensor()] * len(self.groups) + + for k in range(len(self.groups)): + y_hat_prev = torch.cat(y_hat_[:k], dim=1) if k > 0 else Tensor() + params = self._get_ctx_params(k, side_params, y_hat_prev) + y_out_[k] = self.latent_codec[f"y{k}"](y_[k], params) + y_hat_[k] = y_out_[k]["y_hat"] + y_likelihoods_[k] = y_out_[k]["likelihoods"]["y"] + + y_hat = torch.cat(y_hat_, dim=1) + y_likelihoods = torch.cat(y_likelihoods_, dim=1) + + return { + "likelihoods": { + "y": y_likelihoods, + }, + "y_hat": y_hat, + } + + def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + y_ = torch.split(y, self.groups, dim=1) + y_out_ = [{}] * len(self.groups) + y_hat = torch.zeros_like(y) + + for k in range(len(self.groups)): + y_hat_prev = y_hat[:, : self.groups_acc[k]] + params = self._get_ctx_params(k, side_params, y_hat_prev) + y_out_[k] = self.latent_codec[f"y{k}"].compress(y_[k], params) + y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"] + + y_strings_groups = [y_out["strings"] for y_out in y_out_] + assert all(len(y_strings_groups[0]) == len(ss) for ss in y_strings_groups) + + return { + "strings": [s for ss in y_strings_groups for s in ss], + "shape": [y_out["shape"] for y_out in y_out_], + "y_hat": y_hat, + } + + def decompress( + self, + strings: List[List[bytes]], + shape: List[Tuple[int, ...]], + side_params: Tensor, + ) -> Dict[str, Any]: + n = len(strings[0]) + assert all(len(ss) == n for ss in strings) + strings_per_group = len(strings) // len(self.groups) + + y_out_ = [{}] * len(self.groups) + y_shape = (sum(s[0] for s in shape), *shape[0][1:]) + y_hat = torch.zeros((n, *y_shape), device=side_params.device) + + for k in range(len(self.groups)): + y_hat_prev = y_hat[:, : self.groups_acc[k]] + params = self._get_ctx_params(k, side_params, y_hat_prev) + y_strings_k = strings[strings_per_group * k : strings_per_group * (k + 1)] + y_out_[k] = self.latent_codec[f"y{k}"].decompress( + y_strings_k, shape[k], params + ) + y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"] + + return { + "y_hat": y_hat, + } + + def _get_ctx_params( + self, k: int, side_params: Tensor, y_hat_prev: Tensor + ) -> Tensor: + if k == 0: + return side_params + params_ch = self.channel_context[f"y{k}"](y_hat_prev) + return torch.cat([side_params, params_ch], dim=1) diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index d85846f8..a9fae68a 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -27,6 +27,8 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import math + from typing import Any import torch @@ -47,6 +49,7 @@ "conv3x3", "subpel_conv3x3", "QReLU", + "sequential_channel_ramp", ] @@ -321,3 +324,37 @@ def backward(ctx, grad_output): grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value] return grad_input, None, None + + +def sequential_channel_ramp( + in_ch: int, + out_ch: int, + *, + num_layers: int = 3, + interp: str = "linear", + make_layer=None, + make_act=None, + skip_last_act: bool = True, + **layer_kwargs, +) -> nn.Module: + """Interleave layers of gradually ramping channels with nonlinearities.""" + channels = ramp(in_ch, out_ch, num_layers + 1, method=interp).floor().int().tolist() + layers = [ + module + for ch_in, ch_out in zip(channels[:-1], channels[1:]) + for module in [ + make_layer(ch_in, ch_out, **layer_kwargs), + make_act(), + ] + ] + if skip_last_act: + layers = layers[:-1] + return nn.Sequential(*layers) + + +def ramp(a, b, steps=None, method="linear", **kwargs): + if method == "linear": + return torch.linspace(a, b, steps, **kwargs) + if method == "log": + return torch.logspace(math.log10(a), math.log10(b), steps, **kwargs) + raise ValueError(f"Unknown ramp method: {method}") diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index bd260447..a8909d9d 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -27,29 +27,34 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from itertools import accumulate + import torch.nn as nn from compressai.entropy_models import EntropyBottleneck from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, CheckerboardLatentCodec, GaussianConditionalLatentCodec, HyperLatentCodec, HyperpriorLatentCodec, ) from compressai.layers import ( + CheckerboardMaskedConv2d, ResidualBlock, ResidualBlockUpsample, ResidualBlockWithStride, conv3x3, + sequential_channel_ramp, subpel_conv3x3, ) -from compressai.layers.layers import CheckerboardMaskedConv2d from compressai.registry import register_model from .base import SimpleVAECompressionModel __all__ = [ "Cheng2020AnchorCheckerboard", + "Cheng2020AnchorElic", ] @@ -153,3 +158,140 @@ def from_state_dict(cls, state_dict): net = cls(N) net.load_state_dict(state_dict) return net + + +@register_model("cheng2020-anchor-elic") +class Cheng2020AnchorElic(SimpleVAECompressionModel): + """Cheng2020 anchor model with checkerboard context model. + + Base transform model from [Cheng2020]. Context model from [He2022]. + + [Cheng2020]: `"Learned Image Compression with Discretized Gaussian + Mixture Likelihoods and Attention Modules" + `_, by Zhengxue Cheng, Heming Sun, + Masaru Takeuchi, and Jiro Katto, CVPR 2020. + + [He2022]: `"ELIC: Efficient Learned Image Compression with + Unevenly Grouped Space-Channel Contextual Adaptive Coding" + `_, by Dailan He, Ziming Yang, + Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. + + Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel + convolutions for up-sampling. 
+ + Args: + N (int): Number of channels + groups (list[int]): Number of channels in each channel group + """ + + def __init__(self, N=192, groups=None, **kwargs): + super().__init__(**kwargs) + + if groups is None: + groups = [16, 16, 32, 64, 64] + + assert sum(groups) == N + self.groups = list(groups) + self.groups_acc = list(accumulate(self.groups, initial=0)) + + self.g_a = nn.Sequential( + ResidualBlockWithStride(3, N, stride=2), + ResidualBlock(N, N), + ResidualBlockWithStride(N, N, stride=2), + ResidualBlock(N, N), + ResidualBlockWithStride(N, N, stride=2), + ResidualBlock(N, N), + conv3x3(N, N, stride=2), + ) + + self.g_s = nn.Sequential( + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + ResidualBlockUpsample(N, N, 2), + ResidualBlock(N, N), + subpel_conv3x3(N, 3, 2), + ) + + h_a = nn.Sequential( + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N, stride=2), + nn.LeakyReLU(inplace=True), + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + conv3x3(N, N, stride=2), + ) + + h_s = nn.Sequential( + conv3x3(N, N), + nn.LeakyReLU(inplace=True), + subpel_conv3x3(N, N, 2), + nn.LeakyReLU(inplace=True), + conv3x3(N, N * 3 // 2), + nn.LeakyReLU(inplace=True), + subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2), + nn.LeakyReLU(inplace=True), + conv3x3(N * 3 // 2, N * 2), + ) + + self.latent_codec = HyperpriorLatentCodec( + latent_codec={ + "y": ChannelGroupsLatentCodec( + groups=self.groups, + channel_context={ + f"y{k}": sequential_channel_ramp( + self.groups_acc[k], + self.groups[k] * 2, + num_layers=3, + make_layer=nn.Conv2d, + make_act=lambda: nn.ReLU(inplace=True), + kernel_size=5, + padding=2, + stride=1, + ) + for k in range(1, len(self.groups)) + }, + latent_codec={ + f"y{k}": CheckerboardLatentCodec( + latent_codec={ + "y": GaussianConditionalLatentCodec(quantizer="ste"), + }, + entropy_parameters=sequential_channel_ramp( + N * 2 + self.groups[k] * (2 if k == 0 else 4), + self.groups[k] * 2, + num_layers=3, + make_layer=nn.Conv2d, + make_act=lambda: nn.ReLU(inplace=True), + kernel_size=1, + padding=0, + stride=1, + ), + context_prediction=CheckerboardMaskedConv2d( + self.groups[k], + self.groups[k] * 2, + kernel_size=5, + padding=2, + stride=1, + ), + ) + for k in range(len(self.groups)) + }, + ), + "hyper": HyperLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + ), + }, + ) + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.conv1.weight"].size(0) + net = cls(N) + net.load_state_dict(state_dict) + return net diff --git a/docs/source/latent_codecs.rst b/docs/source/latent_codecs.rst index b3ca3b6a..d5543b5b 100644 --- a/docs/source/latent_codecs.rst +++ b/docs/source/latent_codecs.rst @@ -31,6 +31,8 @@ CompressAI provides the following predefined :py:class:`~LatentCodec` subclasses - Like :py:class:`~HyperLatentCodec`, but with trainable gain vectors for ``z``. * - :py:class:`~GainHyperpriorLatentCodec` - Like :py:class:`~HyperpriorLatentCodec`, but with trainable gain vectors for ``y``. + * - :py:class:`~ChannelGroupsLatentCodec` + - Encodes ``y`` in multiple chunked groups, each group conditioned on previously encoded groups. * - :py:class:`~CheckerboardLatentCodec` - Encodes ``y`` in two passes in checkerboard order. @@ -332,6 +334,11 @@ GainHyperpriorLatentCodec .. 
autoclass:: GainHyperpriorLatentCodec +ChannelGroupsLatentCodec +------------------------ +.. autoclass:: ChannelGroupsLatentCodec + + CheckerboardLatentCodec ----------------------- .. autoclass:: CheckerboardLatentCodec diff --git a/docs/source/models.rst b/docs/source/models.rst index 0d69f59d..6dc0df24 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -50,6 +50,11 @@ Cheng2020AnchorCheckerboard .. autoclass:: Cheng2020AnchorCheckerboard +Cheng2020AnchorElic +--------------------------- +.. autoclass:: Cheng2020AnchorElic + + .. currentmodule:: compressai.models.video ScaleSpaceFlow From e1b5335631da6f7d902b57ebe8778055cd3e4ef5 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 21 Jul 2023 00:37:22 -0700 Subject: [PATCH 08/37] refactor: improve clarity by labeling w.r.t. ELIC paper --- compressai/models/sensetime.py | 98 +++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 38 deletions(-) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index a8909d9d..eddb9a75 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -239,49 +239,71 @@ def __init__(self, N=192, groups=None, **kwargs): conv3x3(N * 3 // 2, N * 2), ) + # In [He2022], this is labeled "g_ch^(k)". + channel_context = { + f"y{k}": sequential_channel_ramp( + self.groups_acc[k], + self.groups[k] * 2, + num_layers=3, + make_layer=nn.Conv2d, + make_act=lambda: nn.ReLU(inplace=True), + kernel_size=5, + padding=2, + stride=1, + ) + for k in range(1, len(self.groups)) + } + + # In [He2022], this is labeled "g_sp^(k)". + spatial_context = [ + CheckerboardMaskedConv2d( + self.groups[k], + self.groups[k] * 2, + kernel_size=5, + padding=2, + stride=1, + ) + for k in range(len(self.groups)) + ] + + # In [He2022], this is labeled "Param Aggregation". + param_aggregation = [ + sequential_channel_ramp( + N * 2 + self.groups[k] * (2 if k == 0 else 4), + self.groups[k] * 2, + num_layers=3, + make_layer=nn.Conv2d, + make_act=lambda: nn.ReLU(inplace=True), + kernel_size=1, + padding=0, + stride=1, + ) + for k in range(len(self.groups)) + ] + + # In [He2022], this is labeled the space-channel context model (SCCTX). + # The side params and channel context params are computed externally. + scctx_latent_codec = { + f"y{k}": CheckerboardLatentCodec( + latent_codec={ + "y": GaussianConditionalLatentCodec(quantizer="ste"), + }, + context_prediction=spatial_context[k], + entropy_parameters=param_aggregation[k], + ) + for k in range(len(self.groups)) + } + + # [He2022] uses a "hyperprior" architecture, which reconstructs y using z. 
self.latent_codec = HyperpriorLatentCodec( latent_codec={ + # Channel groups with space-channel context model (SCCTX): "y": ChannelGroupsLatentCodec( groups=self.groups, - channel_context={ - f"y{k}": sequential_channel_ramp( - self.groups_acc[k], - self.groups[k] * 2, - num_layers=3, - make_layer=nn.Conv2d, - make_act=lambda: nn.ReLU(inplace=True), - kernel_size=5, - padding=2, - stride=1, - ) - for k in range(1, len(self.groups)) - }, - latent_codec={ - f"y{k}": CheckerboardLatentCodec( - latent_codec={ - "y": GaussianConditionalLatentCodec(quantizer="ste"), - }, - entropy_parameters=sequential_channel_ramp( - N * 2 + self.groups[k] * (2 if k == 0 else 4), - self.groups[k] * 2, - num_layers=3, - make_layer=nn.Conv2d, - make_act=lambda: nn.ReLU(inplace=True), - kernel_size=1, - padding=0, - stride=1, - ), - context_prediction=CheckerboardMaskedConv2d( - self.groups[k], - self.groups[k] * 2, - kernel_size=5, - padding=2, - stride=1, - ), - ) - for k in range(len(self.groups)) - }, + channel_context=channel_context, + latent_codec=scctx_latent_codec, ), + # Side information branch containing z: "hyper": HyperLatentCodec( entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s ), From f74400343fb6676f7fff920357ae9c036c95340b Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sat, 22 Jul 2023 03:45:44 -0700 Subject: [PATCH 09/37] docs: checkerboard swap filled/empty colors --- compressai/latent_codecs/checkerboard.py | 40 ++++++++++++------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 4bdf918d..c364cd5a 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -67,32 +67,32 @@ class CheckerboardLatentCodec(LatentCodec): 0. Input: - ■ ■ ■ ■ - ■ ■ ■ ■ - ■ ■ ■ ■ + □ □ □ □ + □ □ □ □ + □ □ □ □ 1. Decode anchors: - ◌ ■ ◌ ■ - ■ ◌ ■ ◌ - ◌ ■ ◌ ■ + ◌ □ ◌ □ + □ ◌ □ ◌ + ◌ □ ◌ □ 2. Decode non-anchors: - □ ◌ □ ◌ - ◌ □ ◌ □ - □ ◌ □ ◌ + ■ ◌ ■ ◌ + ◌ ■ ◌ ■ + ■ ◌ ■ ◌ 3. End result: - □ □ □ □ - □ □ □ □ - □ □ □ □ + ■ ■ ■ ■ + ■ ■ ■ ■ + ■ ■ ■ ■ LEGEND: - □ decoded + ■ decoded ◌ currently decoding - ■ empty + □ empty """ latent_codec: Mapping[str, LatentCodec] @@ -216,9 +216,9 @@ def unembed(self, y: Tensor) -> Tensor: .. code-block:: none - □ ■ □ ■ □ □ ■ ■ - ■ □ ■ □ ---> □ □ ■ ■ - □ ■ □ ■ □ □ ■ ■ + ■ □ ■ □ ■ ■ □ □ + □ ■ □ ■ ---> ■ ■ □ □ + ■ □ ■ □ ■ ■ □ □ """ n, c, h, w = y.shape y_ = y.new_zeros((2, n, c, h, w // 2)) @@ -233,9 +233,9 @@ def embed(self, y_: Tensor) -> Tensor: .. code-block:: none - □ □ ■ ■ □ ■ □ ■ - □ □ ■ ■ ---> ■ □ ■ □ - □ □ ■ ■ □ ■ □ ■ + ■ ■ □ □ ■ □ ■ □ + ■ ■ □ □ ---> □ ■ □ ■ + ■ ■ □ □ ■ □ ■ □ """ num_chunks, n, c, h, w_half = y_.shape assert num_chunks == 2 From 649a19619375466f26d915f2b655be30cb3c4fa1 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 23:40:17 -0700 Subject: [PATCH 10/37] fix: relax latent_codecs to ignore extra parameters via **kwargs Relaxing the signature allows us to do things like: ```python out_enc = module.compress(...) out_dec = module.decompress(**out_enc) ``` ...without fear for decompress() signature incompatibility when compress() outputs information not needed by decompress(). 
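For example, EntropyBottleneckLatentCodec.compress() returns a "y_hat"
entry that its decompress() has no named parameter for; with the relaxed
signatures below, such extra entries are simply absorbed (a sketch,
assuming an already-constructed and update()'d codec):

```python
out_enc = codec.compress(y)
# e.g. {"strings": [...], "shape": (h, w), "y_hat": ...}
out_dec = codec.decompress(**out_enc)
# "strings" and "shape" bind to named parameters; "y_hat" lands in **kwargs.
```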
--- compressai/latent_codecs/channel_groups.py | 1 + compressai/latent_codecs/checkerboard.py | 6 +++++- compressai/latent_codecs/entropy_bottleneck.py | 2 +- compressai/latent_codecs/gain/hyper.py | 6 +++++- compressai/latent_codecs/gain/hyperprior.py | 1 + compressai/latent_codecs/gaussian_conditional.py | 6 +++++- compressai/latent_codecs/hyper.py | 2 +- compressai/latent_codecs/hyperprior.py | 2 +- compressai/latent_codecs/rasterscan.py | 6 +++++- compressai/models/base.py | 4 ++-- 10 files changed, 27 insertions(+), 9 deletions(-) diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index 0cd74baa..cc2a1e86 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -131,6 +131,7 @@ def decompress( strings: List[List[bytes]], shape: List[Tuple[int, ...]], side_params: Tensor, + **kwargs, ) -> Dict[str, Any]: n = len(strings[0]) assert all(len(ss) == n for ss in strings) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index c364cd5a..236a8b2e 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -186,7 +186,11 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: } def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, ...], side_params: Tensor + self, + strings: List[List[bytes]], + shape: Tuple[int, ...], + side_params: Tensor, + **kwargs, ) -> Dict[str, Any]: y_strings_ = strings n = len(y_strings_[0]) diff --git a/compressai/latent_codecs/entropy_bottleneck.py b/compressai/latent_codecs/entropy_bottleneck.py index a0d935c5..94ba6a28 100644 --- a/compressai/latent_codecs/entropy_bottleneck.py +++ b/compressai/latent_codecs/entropy_bottleneck.py @@ -80,7 +80,7 @@ def compress(self, y: Tensor) -> Dict[str, Any]: return {"strings": [y_strings], "shape": shape, "y_hat": y_hat} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int] + self, strings: List[List[bytes]], shape: Tuple[int, int], **kwargs ) -> Dict[str, Any]: (y_strings,) = strings y_hat = self.entropy_bottleneck.decompress(y_strings, shape) diff --git a/compressai/latent_codecs/gain/hyper.py b/compressai/latent_codecs/gain/hyper.py index 7f74a1e7..c019b03c 100644 --- a/compressai/latent_codecs/gain/hyper.py +++ b/compressai/latent_codecs/gain/hyper.py @@ -102,7 +102,11 @@ def compress(self, y: Tensor, gain: Tensor, gain_inv: Tensor) -> Dict[str, Any]: return {"strings": [z_strings], "shape": shape, "params": params} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int], gain_inv: Tensor + self, + strings: List[List[bytes]], + shape: Tuple[int, int], + gain_inv: Tensor, + **kwargs, ) -> Dict[str, Any]: (z_strings,) = strings z_hat = self.entropy_bottleneck.decompress(z_strings, shape) diff --git a/compressai/latent_codecs/gain/hyperprior.py b/compressai/latent_codecs/gain/hyperprior.py index 5d7b4ab3..3d5aa311 100644 --- a/compressai/latent_codecs/gain/hyperprior.py +++ b/compressai/latent_codecs/gain/hyperprior.py @@ -152,6 +152,7 @@ def decompress( shape: Dict[str, Tuple[int, ...]], y_gain_inv: Tensor, z_gain_inv: Tensor, + **kwargs, ) -> Dict[str, Any]: *y_strings_, z_strings = strings assert all(len(y_strings) == len(z_strings) for y_strings in y_strings_) diff --git a/compressai/latent_codecs/gaussian_conditional.py b/compressai/latent_codecs/gaussian_conditional.py index cbd6a114..db6caf58 100644 --- a/compressai/latent_codecs/gaussian_conditional.py +++ 
b/compressai/latent_codecs/gaussian_conditional.py @@ -113,7 +113,11 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: return {"strings": [y_strings], "shape": y.shape[2:4], "y_hat": y_hat} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int], ctx_params: Tensor + self, + strings: List[List[bytes]], + shape: Tuple[int, int], + ctx_params: Tensor, + **kwargs, ) -> Dict[str, Any]: (y_strings,) = strings gaussian_params = self.entropy_parameters(ctx_params) diff --git a/compressai/latent_codecs/hyper.py b/compressai/latent_codecs/hyper.py index 02c86689..1ebeb311 100644 --- a/compressai/latent_codecs/hyper.py +++ b/compressai/latent_codecs/hyper.py @@ -96,7 +96,7 @@ def compress(self, y: Tensor) -> Dict[str, Any]: return {"strings": [z_strings], "shape": shape, "params": params} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int] + self, strings: List[List[bytes]], shape: Tuple[int, int], **kwargs ) -> Dict[str, Any]: (z_strings,) = strings z_hat = self.entropy_bottleneck.decompress(z_strings, shape) diff --git a/compressai/latent_codecs/hyperprior.py b/compressai/latent_codecs/hyperprior.py index 3a2db19e..94d00e5e 100644 --- a/compressai/latent_codecs/hyperprior.py +++ b/compressai/latent_codecs/hyperprior.py @@ -125,7 +125,7 @@ def compress(self, y: Tensor) -> Dict[str, Any]: } def decompress( - self, strings: List[List[bytes]], shape: Dict[str, Tuple[int, ...]] + self, strings: List[List[bytes]], shape: Dict[str, Tuple[int, ...]], **kwargs ) -> Dict[str, Any]: *y_strings_, z_strings = strings assert all(len(y_strings) == len(z_strings) for y_strings in y_strings_) diff --git a/compressai/latent_codecs/rasterscan.py b/compressai/latent_codecs/rasterscan.py index 14d53905..3b2044e6 100644 --- a/compressai/latent_codecs/rasterscan.py +++ b/compressai/latent_codecs/rasterscan.py @@ -137,7 +137,11 @@ def _compress_single(self, **kwargs): return {"strings": [y_strings], "y_hat": y_hat.squeeze(0)} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int], ctx_params: Tensor + self, + strings: List[List[bytes]], + shape: Tuple[int, int], + ctx_params: Tensor, + **kwargs, ) -> Dict[str, Any]: (y_strings,) = strings y_height, y_width = shape diff --git a/compressai/models/base.py b/compressai/models/base.py index 0f326ba0..f27f9c01 100644 --- a/compressai/models/base.py +++ b/compressai/models/base.py @@ -200,8 +200,8 @@ def compress(self, x): outputs = self.latent_codec.compress(y) return outputs - def decompress(self, strings, shape): - y_out = self.latent_codec.decompress(strings, shape) + def decompress(self, *args, **kwargs): + y_out = self.latent_codec.decompress(*args, **kwargs) y_hat = y_out["y_hat"] x_hat = self.g_s(y_hat).clamp_(0, 1) return { From 63604e1a5946f5ab1529c5b8f2bb0a9acecdfe39 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 23:50:12 -0700 Subject: [PATCH 11/37] feat: latent_codec convenient dict lookup Instead of ```python latent_codec.latent_codec[key] ``` users can now do ```python latent_codec[key] ``` directly. 
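The lookup also chains through nested codecs. For example, with the
Cheng2020AnchorCheckerboard model from earlier in this series (an
illustrative sketch):

```python
model = Cheng2020AnchorCheckerboard(N=192)
hyper = model["hyper"]   # same as model.latent_codec.latent_codec["hyper"]
y_codec = model["y"]     # the CheckerboardLatentCodec
inner = model["y"]["y"]  # its inner GaussianConditionalLatentCodec
```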
---
 compressai/latent_codecs/channel_groups.py | 3 +++
 compressai/latent_codecs/checkerboard.py | 3 +++
 compressai/latent_codecs/gain/hyperprior.py | 3 +++
 compressai/latent_codecs/hyperprior.py | 3 +++
 compressai/models/base.py | 3 +++
 5 files changed, 15 insertions(+)

diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py
index cc2a1e86..3caf62cb 100644
--- a/compressai/latent_codecs/channel_groups.py
+++ b/compressai/latent_codecs/channel_groups.py
@@ -83,6 +83,9 @@ def __init__(
         self.channel_context = nn.ModuleDict(channel_context)
         self.latent_codec = nn.ModuleDict(latent_codec)
 
+    def __getitem__(self, key: str) -> LatentCodec:
+        return self.latent_codec[key]
+
     def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
         y_ = torch.split(y, self.groups, dim=1)
         y_out_ = [{}] * len(self.groups)

diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py
index 236a8b2e..5d351fd8 100644
--- a/compressai/latent_codecs/checkerboard.py
+++ b/compressai/latent_codecs/checkerboard.py
@@ -122,6 +122,9 @@ def __init__(
             save_direct=True,
         )
 
+    def __getitem__(self, key: str) -> LatentCodec:
+        return self.latent_codec[key]
+
     def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
         if self.forward_method == "twopass":
             return self._forward_twopass(y, side_params)

diff --git a/compressai/latent_codecs/gain/hyperprior.py b/compressai/latent_codecs/gain/hyperprior.py
index 3d5aa311..cb4affdc 100644
--- a/compressai/latent_codecs/gain/hyperprior.py
+++ b/compressai/latent_codecs/gain/hyperprior.py
@@ -110,6 +110,9 @@ def __init__(
             save_direct=True,
         )
 
+    def __getitem__(self, key: str) -> LatentCodec:
+        return self.latent_codec[key]
+
     def forward(
         self,
         y: Tensor,

diff --git a/compressai/latent_codecs/hyperprior.py b/compressai/latent_codecs/hyperprior.py
index 94d00e5e..9781353a 100644
--- a/compressai/latent_codecs/hyperprior.py
+++ b/compressai/latent_codecs/hyperprior.py
@@ -103,6 +103,9 @@ def __init__(
             save_direct=True,
         )
 
+    def __getitem__(self, key: str) -> LatentCodec:
+        return self.latent_codec[key]
+
     def forward(self, y: Tensor) -> Dict[str, Any]:
         hyper_out = self.latent_codec["hyper"](y)
         y_out = self.latent_codec["y"](y, hyper_out["params"])

diff --git a/compressai/models/base.py b/compressai/models/base.py
index f27f9c01..f1175d16 100644
--- a/compressai/models/base.py
+++ b/compressai/models/base.py
@@ -185,6 +185,9 @@ class SimpleVAECompressionModel(CompressionModel):
     g_s: nn.Module
     latent_codec: LatentCodec
 
+    def __getitem__(self, key: str) -> LatentCodec:
+        return self.latent_codec[key]
+
     def forward(self, x):
         y = self.g_a(x)
         y_out = self.latent_codec(y)

From e815f0c4ce40545b0ac8560d462408128305d9d3 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Fri, 29 Sep 2023 18:19:25 -0700
Subject: [PATCH 12/37] feat: channel groups: expose merge as overridable

Users may now override the merge method, select which items to merge,
and in which order.
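For instance, a subclass could feed the channel context network the most
recently decoded groups first (an illustrative sketch, not part of this
patch):

```python
class ReversedMergeChannelGroupsLatentCodec(ChannelGroupsLatentCodec):
    def merge_y(self, *args):
        # Concatenate the most recently decoded group first.
        return torch.cat(args[::-1], dim=1)
```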
--- compressai/latent_codecs/channel_groups.py | 32 +++++++++++++--------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index 3caf62cb..de99a9ce 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -93,8 +93,7 @@ def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_likelihoods_ = [Tensor()] * len(self.groups) for k in range(len(self.groups)): - y_hat_prev = torch.cat(y_hat_[:k], dim=1) if k > 0 else Tensor() - params = self._get_ctx_params(k, side_params, y_hat_prev) + params = self._get_ctx_params(k, side_params, y_hat_) y_out_[k] = self.latent_codec[f"y{k}"](y_[k], params) y_hat_[k] = y_out_[k]["y_hat"] y_likelihoods_[k] = y_out_[k]["likelihoods"]["y"] @@ -113,12 +112,12 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_ = torch.split(y, self.groups, dim=1) y_out_ = [{}] * len(self.groups) y_hat = torch.zeros_like(y) + y_hat_ = y_hat.split(self.groups, dim=1) for k in range(len(self.groups)): - y_hat_prev = y_hat[:, : self.groups_acc[k]] - params = self._get_ctx_params(k, side_params, y_hat_prev) + params = self._get_ctx_params(k, side_params, y_hat_) y_out_[k] = self.latent_codec[f"y{k}"].compress(y_[k], params) - y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"] + y_hat_[k][:] = y_out_[k]["y_hat"] y_strings_groups = [y_out["strings"] for y_out in y_out_] assert all(len(y_strings_groups[0]) == len(ss) for ss in y_strings_groups) @@ -143,24 +142,31 @@ def decompress( y_out_ = [{}] * len(self.groups) y_shape = (sum(s[0] for s in shape), *shape[0][1:]) y_hat = torch.zeros((n, *y_shape), device=side_params.device) + y_hat_ = y_hat.split(self.groups, dim=1) for k in range(len(self.groups)): - y_hat_prev = y_hat[:, : self.groups_acc[k]] - params = self._get_ctx_params(k, side_params, y_hat_prev) - y_strings_k = strings[strings_per_group * k : strings_per_group * (k + 1)] + params = self._get_ctx_params(k, side_params, y_hat_) y_out_[k] = self.latent_codec[f"y{k}"].decompress( - y_strings_k, shape[k], params + strings[strings_per_group * k : strings_per_group * (k + 1)], + shape[k], + params, ) - y_hat[:, self.groups_acc[k] : self.groups_acc[k + 1]] = y_out_[k]["y_hat"] + y_hat_[k][:] = y_out_[k]["y_hat"] return { "y_hat": y_hat, } + def merge_y(self, *args): + return torch.cat(args, dim=1) + + def merge_params(self, *args): + return torch.cat(args, dim=1) + def _get_ctx_params( - self, k: int, side_params: Tensor, y_hat_prev: Tensor + self, k: int, side_params: Tensor, y_hat_: List[Tensor] ) -> Tensor: if k == 0: return side_params - params_ch = self.channel_context[f"y{k}"](y_hat_prev) - return torch.cat([side_params, params_ch], dim=1) + ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*y_hat_[:k])) + return self.merge_params(side_params, ch_ctx_params) From 1a61fff4926e195a349cec5c1e363cd574002c90 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 19:04:53 -0700 Subject: [PATCH 13/37] fix: use same params merge order as ELIC paper (for compatibility) --- compressai/latent_codecs/channel_groups.py | 2 +- compressai/latent_codecs/checkerboard.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index de99a9ce..52817a61 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ 
-169,4 +169,4 @@ def _get_ctx_params( if k == 0: return side_params ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*y_hat_[:k])) - return self.merge_params(side_params, ch_ctx_params) + return self.merge_params(ch_ctx_params, side_params) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 5d351fd8..0cb45657 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -130,7 +130,7 @@ def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: return self._forward_twopass(y, side_params) y_hat = self.quantize(y) y_ctx = self._mask_anchor(self.context_prediction(y_hat)) - ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx)) + ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) y_out = self.latent_codec["y"](y, ctx_params) return { "likelihoods": { @@ -143,7 +143,7 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: """Do context prediction on STE-quantized y_hat instead.""" y_hat_anchors = self._y_hat_anchors(y, side_params) y_ctx = self._mask_anchor(self.context_prediction(y_hat_anchors)) - ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx)) + ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) y_out = self.latent_codec["y"](y, ctx_params) # Reuse quantized y_hat that was used for non-anchor context prediction. y_hat = y_out["y_hat"] @@ -159,7 +159,7 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: def _y_hat_anchors(self, y, side_params): y_ctx = self.context_prediction(y).detach() y_ctx[:] = 0 - ctx_params = self.entropy_parameters(self.merge(side_params, y_ctx)) + ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) ctx_params = self._mask_non_anchor(ctx_params) # Probably not needed. 
_, means_hat = ctx_params.chunk(2, 1) @@ -175,7 +175,7 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: for i in range(2): y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] - ctx_params_i = self.entropy_parameters(self.merge(side_params_[i], y_ctx_i)) + ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) y_out = self.latent_codec["y"].compress(y_[i], ctx_params_i) y_hat_[i] = y_out["y_hat"] [y_strings_[i]] = y_out["strings"] @@ -206,7 +206,7 @@ def decompress( for i in range(2): y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] - ctx_params_i = self.entropy_parameters(self.merge(side_params_[i], y_ctx_i)) + ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) y_out = self.latent_codec["y"].decompress( [y_strings_[i]], shape=(h, w // 2), ctx_params=ctx_params_i ) From b31ca22a6a3b6984ecdc7641fb5c1f79dcde57fd Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sat, 30 Sep 2023 01:56:48 -0700 Subject: [PATCH 14/37] feat: allow choosing Checkerboard parity --- compressai/latent_codecs/checkerboard.py | 63 ++++++++++++++++++------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 0cb45657..b2258999 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -105,11 +105,14 @@ def __init__( latent_codec: Optional[Mapping[str, LatentCodec]] = None, entropy_parameters: Optional[nn.Module] = None, context_prediction: Optional[nn.Module] = None, + anchor_parity="even", forward_method="twopass", **kwargs, ): super().__init__() self._kwargs = kwargs + self.anchor_parity = anchor_parity + self.non_anchor_parity = {"odd": "even", "even": "odd"}[anchor_parity] self.forward_method = forward_method self.entropy_parameters = entropy_parameters or nn.Identity() self.context_prediction = context_prediction or nn.Identity() @@ -147,8 +150,7 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_out = self.latent_codec["y"](y, ctx_params) # Reuse quantized y_hat that was used for non-anchor context prediction. 
y_hat = y_out["y_hat"] - y_hat[..., 0::2, 0::2] = y_hat_anchors[..., 0::2, 0::2] - y_hat[..., 1::2, 1::2] = y_hat_anchors[..., 1::2, 1::2] + self._copy_anchors(y_hat, y_hat_anchors) return { "likelihoods": { "y": y_out["likelihoods"]["y"], @@ -229,10 +231,16 @@ def unembed(self, y: Tensor) -> Tensor: """ n, c, h, w = y.shape y_ = y.new_zeros((2, n, c, h, w // 2)) - y_[0, ..., 0::2, :] = y[..., 0::2, 0::2] - y_[0, ..., 1::2, :] = y[..., 1::2, 1::2] - y_[1, ..., 0::2, :] = y[..., 0::2, 1::2] - y_[1, ..., 1::2, :] = y[..., 1::2, 0::2] + if self.anchor_parity == "even": + y_[0, ..., 0::2, :] = y[..., 0::2, 0::2] + y_[0, ..., 1::2, :] = y[..., 1::2, 1::2] + y_[1, ..., 0::2, :] = y[..., 0::2, 1::2] + y_[1, ..., 1::2, :] = y[..., 1::2, 0::2] + else: + y_[0, ..., 0::2, :] = y[..., 0::2, 1::2] + y_[0, ..., 1::2, :] = y[..., 1::2, 0::2] + y_[1, ..., 0::2, :] = y[..., 0::2, 0::2] + y_[1, ..., 1::2, :] = y[..., 1::2, 1::2] return y_ def embed(self, y_: Tensor) -> Tensor: @@ -247,21 +255,44 @@ def embed(self, y_: Tensor) -> Tensor: num_chunks, n, c, h, w_half = y_.shape assert num_chunks == 2 y = y_.new_zeros((n, c, h, w_half * 2)) - y[..., 0::2, 0::2] = y_[0, ..., 0::2, :] - y[..., 1::2, 1::2] = y_[0, ..., 1::2, :] - y[..., 0::2, 1::2] = y_[1, ..., 0::2, :] - y[..., 1::2, 0::2] = y_[1, ..., 1::2, :] + if self.anchor_parity == "even": + y[..., 0::2, 0::2] = y_[0, ..., 0::2, :] + y[..., 1::2, 1::2] = y_[0, ..., 1::2, :] + y[..., 0::2, 1::2] = y_[1, ..., 0::2, :] + y[..., 1::2, 0::2] = y_[1, ..., 1::2, :] + else: + y[..., 0::2, 1::2] = y_[0, ..., 0::2, :] + y[..., 1::2, 0::2] = y_[0, ..., 1::2, :] + y[..., 0::2, 0::2] = y_[1, ..., 0::2, :] + y[..., 1::2, 1::2] = y_[1, ..., 1::2, :] return y - def _mask_anchor(self, y: Tensor) -> Tensor: - y[..., 0::2, 0::2] = 0 - y[..., 1::2, 1::2] = 0 + def _copy(self, dest: Tensor, src: Tensor, parity: str) -> None: + """Copy pixels of the given parity.""" + if parity == "even": + dest[..., 0::2, 0::2] = src[..., 0::2, 0::2] + dest[..., 1::2, 1::2] = src[..., 1::2, 1::2] + else: + dest[..., 0::2, 1::2] = src[..., 0::2, 1::2] + dest[..., 1::2, 0::2] = src[..., 1::2, 0::2] + + def _copy_anchors(self, dest: Tensor, src: Tensor) -> None: + return self._copy(dest, src, parity=self.anchor_parity) + + def _mask(self, y: Tensor, parity: str) -> Tensor: + if parity == "even": + y[..., 0::2, 0::2] = 0 + y[..., 1::2, 1::2] = 0 + else: + y[..., 0::2, 1::2] = 0 + y[..., 1::2, 0::2] = 0 return y + def _mask_anchor(self, y: Tensor) -> Tensor: + return self._mask(y, parity=self.anchor_parity) + def _mask_non_anchor(self, y: Tensor) -> Tensor: - y[..., 0::2, 1::2] = 0 - y[..., 1::2, 0::2] = 0 - return y + return self._mask(y, parity=self.non_anchor_parity) def merge(self, *args): return torch.cat(args, dim=1) From fd8767b2476958c8355d4b98adf70ad79c3e958a Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sat, 30 Sep 2023 19:57:59 -0700 Subject: [PATCH 15/37] refactor: extract _forward_onepass --- compressai/latent_codecs/checkerboard.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index b2258999..2d6bf986 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -129,8 +129,14 @@ def __getitem__(self, key: str) -> LatentCodec: return self.latent_codec[key] def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + if self.forward_method == "onepass": + return self._forward_onepass(y, side_params) if self.forward_method == "twopass": 
return self._forward_twopass(y, side_params) + raise ValueError(f"Unknown forward method: {self.forward_method}") + + def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + """Fast estimation with single pass of the context model.""" y_hat = self.quantize(y) y_ctx = self._mask_anchor(self.context_prediction(y_hat)) ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) From 0d063f63f5dd715d246d7a494bb52d62e48d04eb Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sat, 30 Sep 2023 20:13:35 -0700 Subject: [PATCH 16/37] refactor: inline _y_hat_anchors --- compressai/latent_codecs/checkerboard.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 2d6bf986..7413d700 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -150,13 +150,20 @@ def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: """Do context prediction on STE-quantized y_hat instead.""" - y_hat_anchors = self._y_hat_anchors(y, side_params) + y_ctx = self._y_ctx_zero(y) + ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) + ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) + ctx_params = self._mask_non_anchor(ctx_params) # Probably not needed. + _, means_hat = ctx_params.chunk(2, 1) + y_hat_anchors = quantize_ste(y - means_hat) + means_hat + y_ctx = self._mask_anchor(self.context_prediction(y_hat_anchors)) ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) y_out = self.latent_codec["y"](y, ctx_params) # Reuse quantized y_hat that was used for non-anchor context prediction. y_hat = y_out["y_hat"] self._copy_anchors(y_hat, y_hat_anchors) + return { "likelihoods": { "y": y_out["likelihoods"]["y"], @@ -164,15 +171,11 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: "y_hat": y_hat, } - def _y_hat_anchors(self, y, side_params): + def _y_ctx_zero(self, y): + """Create a zero tensor of the required shape.""" y_ctx = self.context_prediction(y).detach() y_ctx[:] = 0 - ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) - ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) - ctx_params = self._mask_non_anchor(ctx_params) # Probably not needed. - _, means_hat = ctx_params.chunk(2, 1) - y_hat = quantize_ste(y - means_hat) + means_hat - return y_hat + return y_ctx def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: n, c, h, w = y.shape From cdeb5a1cf2ff0a53fca239348c835bfb8c3c9a78 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 16:51:59 -0700 Subject: [PATCH 17/37] feat: GaussianConditionalLatentCodec allow choosing means/scales Some models use scales only; some use scales, means; some use means, scales. Perhaps even models that use means only. 
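Illustrative constructions (a sketch; only the orderings handled by the
`_chunk` helper below are recognized):

```python
from compressai.latent_codecs import GaussianConditionalLatentCodec

# Default split, matching the previously hard-coded behaviour:
lc = GaussianConditionalLatentCodec(chunks=("scales", "means"))

# For parameter networks that emit (means, scales) instead:
lc = GaussianConditionalLatentCodec(chunks=("means", "scales"))

# Scale-only models; `means` stays None and no mean is subtracted:
lc = GaussianConditionalLatentCodec(chunks=("scales",))
```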
--- .../latent_codecs/gaussian_conditional.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/compressai/latent_codecs/gaussian_conditional.py b/compressai/latent_codecs/gaussian_conditional.py index db6caf58..321dd5ff 100644 --- a/compressai/latent_codecs/gaussian_conditional.py +++ b/compressai/latent_codecs/gaussian_conditional.py @@ -85,6 +85,7 @@ def __init__( gaussian_conditional: Optional[GaussianConditional] = None, entropy_parameters: Optional[nn.Module] = None, quantizer: str = "noise", + chunks: Tuple[str] = ("scales", "means"), **kwargs, ): super().__init__() @@ -93,10 +94,11 @@ def __init__( scale_table, **kwargs ) self.entropy_parameters = entropy_parameters or nn.Identity() + self.chunks = tuple(chunks) def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: gaussian_params = self.entropy_parameters(ctx_params) - scales_hat, means_hat = gaussian_params.chunk(2, 1) + scales_hat, means_hat = self._chunk(gaussian_params) y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) if self.quantizer == "ste": y_hat = quantize_ste(y - means_hat) + means_hat @@ -104,7 +106,7 @@ def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: gaussian_params = self.entropy_parameters(ctx_params) - scales_hat, means_hat = gaussian_params.chunk(2, 1) + scales_hat, means_hat = self._chunk(gaussian_params) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_strings = self.gaussian_conditional.compress(y, indexes, means_hat) y_hat = self.gaussian_conditional.decompress( @@ -121,10 +123,22 @@ def decompress( ) -> Dict[str, Any]: (y_strings,) = strings gaussian_params = self.entropy_parameters(ctx_params) - scales_hat, means_hat = gaussian_params.chunk(2, 1) + scales_hat, means_hat = self._chunk(gaussian_params) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_hat = self.gaussian_conditional.decompress( y_strings, indexes, means=means_hat ) assert y_hat.shape[2:4] == shape return {"y_hat": y_hat} + + def _chunk(self, params: Tensor) -> Tuple[Tensor, Tensor]: + scales, means = None, None + if self.chunks == ("scales",): + scales = params + if self.chunks == ("means",): + means = params + if self.chunks == ("scales", "means"): + scales, means = params.chunk(2, 1) + if self.chunks == ("means", "scales"): + means, scales = params.chunk(2, 1) + return scales, means From 2216efb072eb97623d08b72167fbc810441066d8 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 18:00:42 -0700 Subject: [PATCH 18/37] refactor: _keep_only, _mask all --- compressai/latent_codecs/checkerboard.py | 33 +++++++++++++----------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 7413d700..2a228996 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -138,7 +138,7 @@ def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: """Fast estimation with single pass of the context model.""" y_hat = self.quantize(y) - y_ctx = self._mask_anchor(self.context_prediction(y_hat)) + y_ctx = self._keep_only(self.context_prediction(y_hat), "non_anchor") ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) y_out = self.latent_codec["y"](y, ctx_params) return { @@ -153,16 +153,19 @@ def 
_forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_ctx = self._y_ctx_zero(y) ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) - ctx_params = self._mask_non_anchor(ctx_params) # Probably not needed. + ctx_params = self._keep_only(ctx_params, "anchor") # Probably not needed. _, means_hat = ctx_params.chunk(2, 1) y_hat_anchors = quantize_ste(y - means_hat) + means_hat + y_hat_anchors = self._keep_only(y_hat_anchors, "anchor") - y_ctx = self._mask_anchor(self.context_prediction(y_hat_anchors)) + y_ctx = self.context_prediction(y_hat_anchors) + y_ctx = self._keep_only(y_ctx, "non_anchor") # Probably not needed. ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) y_out = self.latent_codec["y"](y, ctx_params) + # Reuse quantized y_hat that was used for non-anchor context prediction. y_hat = y_out["y_hat"] - self._copy_anchors(y_hat, y_hat_anchors) + self._copy(y_hat, y_hat_anchors, "anchor") # Probably not needed. return { "likelihoods": { @@ -276,8 +279,10 @@ def embed(self, y_: Tensor) -> Tensor: y[..., 1::2, 1::2] = y_[1, ..., 1::2, :] return y - def _copy(self, dest: Tensor, src: Tensor, parity: str) -> None: - """Copy pixels of the given parity.""" + def _copy(self, dest: Tensor, src: Tensor, step: str) -> None: + """Copy pixels in the current step.""" + assert step in ("anchor", "non_anchor") + parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity if parity == "even": dest[..., 0::2, 0::2] = src[..., 0::2, 0::2] dest[..., 1::2, 1::2] = src[..., 1::2, 1::2] @@ -285,24 +290,22 @@ def _copy(self, dest: Tensor, src: Tensor, parity: str) -> None: dest[..., 0::2, 1::2] = src[..., 0::2, 1::2] dest[..., 1::2, 0::2] = src[..., 1::2, 0::2] - def _copy_anchors(self, dest: Tensor, src: Tensor) -> None: - return self._copy(dest, src, parity=self.anchor_parity) + def _keep_only(self, y: Tensor, step: str) -> Tensor: + """Keep only pixels in the current step, and zero out the rest.""" + parity = self.non_anchor_parity if step == "anchor" else self.anchor_parity + return self._mask(y, parity) def _mask(self, y: Tensor, parity: str) -> Tensor: if parity == "even": y[..., 0::2, 0::2] = 0 y[..., 1::2, 1::2] = 0 - else: + elif parity == "odd": y[..., 0::2, 1::2] = 0 y[..., 1::2, 0::2] = 0 + elif parity == "all": + y[:] = 0 return y - def _mask_anchor(self, y: Tensor) -> Tensor: - return self._mask(y, parity=self.anchor_parity) - - def _mask_non_anchor(self, y: Tensor) -> Tensor: - return self._mask(y, parity=self.non_anchor_parity) - def merge(self, *args): return torch.cat(args, dim=1) From 18f9282c947c1e861578876b4f99e42d004bbe8e Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 18:06:45 -0700 Subject: [PATCH 19/37] refactor: _y_ctx_zero use _mask --- compressai/latent_codecs/checkerboard.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 2a228996..7022b9d5 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -174,11 +174,10 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: "y_hat": y_hat, } + @torch.no_grad() def _y_ctx_zero(self, y): - """Create a zero tensor of the required shape.""" - y_ctx = self.context_prediction(y).detach() - y_ctx[:] = 0 - return y_ctx + """Create a zero tensor with correct shape for y_ctx.""" + return 
self._mask(self.context_prediction(y).detach(), "all") def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: n, c, h, w = y.shape From 1e9778f43cd2b57bc555355d1c6b308f43839e81 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 18:04:20 -0700 Subject: [PATCH 20/37] fix: mask y_ctx[0] to zero during compress/decompress context_prediction(0) is not 0...! The model expects the first y_ctx to be 0, however. This fixes a big RD bpp mismatch bug. --- compressai/latent_codecs/checkerboard.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 7022b9d5..8e2b9cdc 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -188,6 +188,8 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: for i in range(2): y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] + if i == 0: + y_ctx_i = self._mask(y_ctx_i, "all") ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) y_out = self.latent_codec["y"].compress(y_[i], ctx_params_i) y_hat_[i] = y_out["y_hat"] @@ -219,6 +221,8 @@ def decompress( for i in range(2): y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] + if i == 0: + y_ctx_i = self._mask(y_ctx_i, "all") ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) y_out = self.latent_codec["y"].decompress( [y_strings_[i]], shape=(h, w // 2), ctx_params=ctx_params_i From 9299f1442c7ef3cffedb104aa62548606c76d176 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 18:24:00 -0700 Subject: [PATCH 21/37] fix: CheckerboardLatentCodec allow arbitrary _chunk order --- compressai/latent_codecs/checkerboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 8e2b9cdc..b77f8579 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -154,7 +154,7 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) ctx_params = self._keep_only(ctx_params, "anchor") # Probably not needed. - _, means_hat = ctx_params.chunk(2, 1) + _, means_hat = self.latent_codec["y"]._chunk(ctx_params) y_hat_anchors = quantize_ste(y - means_hat) + means_hat y_hat_anchors = self._keep_only(y_hat_anchors, "anchor") From 58fd627452c4495a1130924712bba6d2f4c30bc4 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 19:08:05 -0700 Subject: [PATCH 22/37] feat: _forward_twopass new implementation --- compressai/latent_codecs/checkerboard.py | 73 ++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index b77f8579..28e93e10 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -53,10 +53,22 @@ class CheckerboardLatentCodec(LatentCodec): Checkerboard context model introduced in [He2021]. + - `forward_method="one_pass"` is fastest, but does not use + quantization based on the intermediate means. + - `forward_method="two_pass"` is slightly slower, but accurately + quantizes based on the intermediate means. + Uses the same operations as [Chandelier2023]. 
+ - `forward_method="two_pass_faster"` uses slightly fewer + redundant operations, but may not work in all cases. + [He2021]: `"Checkerboard Context Model for Efficient Learned Image Compression" `_, by Dailan He, Yaoyan Zheng, Baocheng Sun, Yan Wang, and Hongwei Qin, CVPR 2021. + [Chandelier2023]: `"ELiC-ReImplemetation" + `_, by + Vincent Chandelier, 2023. + .. warning:: This implementation assumes that ``entropy_parameters`` is a pointwise function, e.g., a composition of 1x1 convs and pointwise nonlinearities. @@ -133,6 +145,8 @@ def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: return self._forward_onepass(y, side_params) if self.forward_method == "twopass": return self._forward_twopass(y, side_params) + if self.forward_method == "twopass_faster": + return self._forward_twopass_faster(y, side_params) raise ValueError(f"Unknown forward method: {self.forward_method}") def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: @@ -149,6 +163,65 @@ def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: } def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + """Do context prediction on STE-quantized y_hat instead.""" + B, C, H, W = y.shape + ctx_params = y.new_zeros((B, C * 2, H, W)) + + y_hat_anchors = self._forward_twopass_step( + y, + side_params, + self._y_ctx_zero(y), + ctx_params, + step="anchor", + ) + + y_hat_non_anchors = self._forward_twopass_step( + y, + side_params, + self.context_prediction(y_hat_anchors), + ctx_params, + step="non_anchor", + ) + + # NOTE: We could also use y_hat = y_out["y_hat"] if it uses the same quantizer. + y_hat = y_hat_anchors + y_hat_non_anchors + y_out = self.latent_codec["y"](y, ctx_params) + + return { + "likelihoods": { + "y": y_out["likelihoods"]["y"], + }, + "y_hat": y_hat, + } + + def _forward_twopass_step( + self, y: Tensor, side_params: Tensor, y_ctx: Tensor, params: Tensor, step: str + ) -> Dict[str, Any]: + # NOTE: The _i variables only contain the current step's pixels. + + # Estimate parameters. + params_i = self.entropy_parameters(self.merge(y_ctx, side_params)) + + # Save params for current step. This is later used for entropy estimation. + self._copy(params, params_i, step) + + # Technically, latent_codec may also contain an "entropy_parameters" method. + # Usually, it is identity, though. + params_i = self.latent_codec["y"].entropy_parameters(params_i) + + # Keep only elements needed for current step. + params_i = self._keep_only(params_i, step) + + # NOTE: It's not necessary to mask out the non-step pixels, but it doesn't hurt. + y_i = self._keep_only(y.clone(), step) + + # Determine y_hat for current step. 
+ _, means_i = self.latent_codec["y"]._chunk(params_i) + y_hat_i = self._keep_only(quantize_ste(y_i - means_i) + means_i, step) + + return y_hat_i + + def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: """Do context prediction on STE-quantized y_hat instead.""" y_ctx = self._y_ctx_zero(y) ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) From bd27d3932c7e03a6bdbadd78489eaa8b2940b7f7 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 17:58:08 -0700 Subject: [PATCH 23/37] feat: HyperLatentCodec quantizer selection --- compressai/latent_codecs/hyper.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compressai/latent_codecs/hyper.py b/compressai/latent_codecs/hyper.py index 1ebeb311..e8c5edd4 100644 --- a/compressai/latent_codecs/hyper.py +++ b/compressai/latent_codecs/hyper.py @@ -34,6 +34,7 @@ from torch import Tensor from compressai.entropy_models import EntropyBottleneck +from compressai.ops import quantize_ste from compressai.registry import register_module from .base import LatentCodec @@ -73,6 +74,7 @@ def __init__( entropy_bottleneck: Optional[EntropyBottleneck] = None, h_a: Optional[nn.Module] = None, h_s: Optional[nn.Module] = None, + quantizer: str = "noise", **kwargs, ): super().__init__() @@ -80,10 +82,14 @@ def __init__( self.entropy_bottleneck = entropy_bottleneck self.h_a = h_a or nn.Identity() self.h_s = h_s or nn.Identity() + self.quantizer = quantizer def forward(self, y: Tensor) -> Dict[str, Any]: z = self.h_a(y) z_hat, z_likelihoods = self.entropy_bottleneck(z) + if self.quantizer == "ste": + z_medians = self.entropy_bottleneck._get_medians() + z_hat = quantize_ste(z - z_medians) + z_medians params = self.h_s(z_hat) return {"likelihoods": {"z": z_likelihoods}, "params": params} From d766af2b3d15e12a0cfa0e80c4ab480c3991affd Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 21:44:15 -0700 Subject: [PATCH 24/37] fix: compressai.layers.layers missing export conv1x1 --- compressai/layers/layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index a9fae68a..b8bad8fa 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -46,6 +46,7 @@ "ResidualBlock", "ResidualBlockUpsample", "ResidualBlockWithStride", + "conv1x1", "conv3x3", "subpel_conv3x3", "QReLU", From 4c33f0e3537f1aa9a5b5ad1ac2c74661b840fead Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 06:21:44 -0700 Subject: [PATCH 25/37] feat: elic2022 use official architecture recommendations See figures 5, 6, 7 of [He2022] ELIC paper. --- compressai/models/sensetime.py | 171 +++++++++++++++++++++------------ 1 file changed, 112 insertions(+), 59 deletions(-) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index eddb9a75..a3021d40 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -27,10 +27,11 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from itertools import accumulate - +import torch import torch.nn as nn +from torch import Tensor + from compressai.entropy_models import EntropyBottleneck from compressai.latent_codecs import ( ChannelGroupsLatentCodec, @@ -40,10 +41,12 @@ HyperpriorLatentCodec, ) from compressai.layers import ( + AttentionBlock, CheckerboardMaskedConv2d, ResidualBlock, ResidualBlockUpsample, ResidualBlockWithStride, + conv1x1, conv3x3, sequential_channel_ramp, subpel_conv3x3, @@ -51,10 +54,11 @@ from compressai.registry import register_model from .base import SimpleVAECompressionModel +from .utils import conv, deconv __all__ = [ "Cheng2020AnchorCheckerboard", - "Cheng2020AnchorElic", + "Elic2022Official", ] @@ -142,7 +146,7 @@ def __init__(self, N=192, **kwargs): nn.Conv2d(N * 8 // 3, N * 6 // 3, 1), ), context_prediction=CheckerboardMaskedConv2d( - N, 2 * N, kernel_size=5, padding=2, stride=1 + N, 2 * N, kernel_size=5, stride=1, padding=2 ), ), "hyper": HyperLatentCodec( @@ -160,96 +164,101 @@ def from_state_dict(cls, state_dict): return net -@register_model("cheng2020-anchor-elic") -class Cheng2020AnchorElic(SimpleVAECompressionModel): - """Cheng2020 anchor model with checkerboard context model. +@register_model("elic2022-official") +class Elic2022Official(SimpleVAECompressionModel): + """ELIC 2022; uneven channel groups with checkerboard spatial context. - Base transform model from [Cheng2020]. Context model from [He2022]. - - [Cheng2020]: `"Learned Image Compression with Discretized Gaussian - Mixture Likelihoods and Attention Modules" - `_, by Zhengxue Cheng, Heming Sun, - Masaru Takeuchi, and Jiro Katto, CVPR 2020. + Context model from [He2022]. + Based on modified attention model architecture from [Cheng2020]. [He2022]: `"ELIC: Efficient Learned Image Compression with Unevenly Grouped Space-Channel Contextual Adaptive Coding" `_, by Dailan He, Ziming Yang, Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. - Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel - convolutions for up-sampling. + [Cheng2020]: `"Learned Image Compression with Discretized Gaussian + Mixture Likelihoods and Attention Modules" + `_, by Zhengxue Cheng, Heming Sun, + Masaru Takeuchi, and Jiro Katto, CVPR 2020. 
Args: - N (int): Number of channels + N (int): Number of main network channels + M (int): Number of latent space channels groups (list[int]): Number of channels in each channel group """ - def __init__(self, N=192, groups=None, **kwargs): + def __init__(self, N=192, M=320, groups=None, **kwargs): super().__init__(**kwargs) if groups is None: - groups = [16, 16, 32, 64, 64] + groups = [16, 16, 32, 64, M - 128] - assert sum(groups) == N self.groups = list(groups) - self.groups_acc = list(accumulate(self.groups, initial=0)) + assert sum(self.groups) == M self.g_a = nn.Sequential( - ResidualBlockWithStride(3, N, stride=2), - ResidualBlock(N, N), - ResidualBlockWithStride(N, N, stride=2), - ResidualBlock(N, N), - ResidualBlockWithStride(N, N, stride=2), - ResidualBlock(N, N), - conv3x3(N, N, stride=2), + conv(3, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + conv(N, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + AttentionBlock(N), + conv(N, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + conv(N, M, kernel_size=5, stride=2), + AttentionBlock(M), ) self.g_s = nn.Sequential( - ResidualBlock(N, N), - ResidualBlockUpsample(N, N, 2), - ResidualBlock(N, N), - ResidualBlockUpsample(N, N, 2), - ResidualBlock(N, N), - ResidualBlockUpsample(N, N, 2), - ResidualBlock(N, N), - subpel_conv3x3(N, 3, 2), + AttentionBlock(M), + deconv(M, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + deconv(N, N, kernel_size=5, stride=2), + AttentionBlock(N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + deconv(N, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + deconv(N, 3, kernel_size=5, stride=2), ) h_a = nn.Sequential( - conv3x3(N, N), - nn.LeakyReLU(inplace=True), - conv3x3(N, N), - nn.LeakyReLU(inplace=True), - conv3x3(N, N, stride=2), - nn.LeakyReLU(inplace=True), - conv3x3(N, N), - nn.LeakyReLU(inplace=True), - conv3x3(N, N, stride=2), + conv(M, N, kernel_size=3, stride=1), + nn.ReLU(inplace=True), + conv(N, N, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + conv(N, N, kernel_size=5, stride=2), ) h_s = nn.Sequential( - conv3x3(N, N), - nn.LeakyReLU(inplace=True), - subpel_conv3x3(N, N, 2), - nn.LeakyReLU(inplace=True), - conv3x3(N, N * 3 // 2), - nn.LeakyReLU(inplace=True), - subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2), - nn.LeakyReLU(inplace=True), - conv3x3(N * 3 // 2, N * 2), + deconv(N, N, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + deconv(N, N * 3 // 2, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + deconv(N * 3 // 2, N * 2, kernel_size=3, stride=1), ) # In [He2022], this is labeled "g_ch^(k)". 
channel_context = { f"y{k}": sequential_channel_ramp( - self.groups_acc[k], + sum(self.groups[:k]), self.groups[k] * 2, num_layers=3, make_layer=nn.Conv2d, make_act=lambda: nn.ReLU(inplace=True), kernel_size=5, - padding=2, stride=1, + padding=2, ) for k in range(1, len(self.groups)) } @@ -260,8 +269,8 @@ def __init__(self, N=192, groups=None, **kwargs): self.groups[k], self.groups[k] * 2, kernel_size=5, - padding=2, stride=1, + padding=2, ) for k in range(len(self.groups)) ] @@ -269,14 +278,15 @@ def __init__(self, N=192, groups=None, **kwargs): # In [He2022], this is labeled "Param Aggregation". param_aggregation = [ sequential_channel_ramp( - N * 2 + self.groups[k] * (2 if k == 0 else 4), + # Input: spatial context, channel context, and hyper params. + self.groups[k] * 2 + (k > 0) * self.groups[k] * 2 + N * 2, self.groups[k] * 2, num_layers=3, make_layer=nn.Conv2d, make_act=lambda: nn.ReLU(inplace=True), kernel_size=1, - padding=0, stride=1, + padding=0, ) for k in range(len(self.groups)) ] @@ -313,7 +323,50 @@ def __init__(self, N=192, groups=None, **kwargs): @classmethod def from_state_dict(cls, state_dict): """Return a new model instance from `state_dict`.""" - N = state_dict["g_a.0.conv1.weight"].size(0) + N = state_dict["g_a.0.weight"].size(0) net = cls(N) net.load_state_dict(state_dict) return net + + +class ResidualBottleneckBlock(nn.Module): + """Residual bottleneck block. + + Introduced by [He2016], this block sandwiches a 3x3 convolution + between two 1x1 convolutions which reduce and then restore the + number of channels. This reduces the number of parameters required. + + [He2016]: `"Deep Residual Learning for Image Recognition" + `_, by Kaiming He, Xiangyu Zhang, + Shaoqing Ren, and Jian Sun, CVPR 2016. + + Args: + in_ch (int): Number of input channels + out_ch (int): Number of output channels + """ + + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + + mid_ch = min(in_ch, out_ch) // 2 + self.conv1 = conv1x1(in_ch, mid_ch) + self.conv2 = conv3x3(mid_ch, mid_ch) + self.conv3 = conv1x1(mid_ch, out_ch) + self.relu = nn.ReLU(inplace=True) + + if in_ch != out_ch: + self.skip = conv1x1(in_ch, out_ch) + else: + self.skip = None + + def forward(self, x: Tensor) -> Tensor: + identity = x if self.skip is None else self.skip(x) + + out = x + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.relu(out) + out = self.conv3(out) + + return out + identity From b68b4596784df512740aaf5063a759cccc2103cb Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 21:15:58 -0700 Subject: [PATCH 26/37] fix: insufficient parameters when merely 'interpolating' Instead of using a linear ramp in number of channels between in_ch and out_ch, use some factor of N to determine number of channels. For instance, instead of ramping between `channels = [16, 21, 26, 32]` it makes more sense to expand then contract, e.g. `channels = [16, 192, 192, 32]` This way, we have enough channels to learn a more complex function! 
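A quick numeric check of the intended behaviour, written with plain `torch`
in place of the module's `ramp` helper:

```python
import torch

# Linear ramp from 16 to 32 over 4 layer widths, as before this patch:
channels = torch.linspace(16, 32, steps=4).floor().int()  # [16, 21, 26, 32]

# With min_ch=192, the interior widths are clipped up to N:
channels[1:-1] = channels[1:-1].clip(min=192)             # [16, 192, 192, 32]
assert channels.tolist() == [16, 192, 192, 32]
```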
--- compressai/layers/layers.py | 5 ++++- compressai/models/sensetime.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index b8bad8fa..5bfb855b 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -331,6 +331,7 @@ def sequential_channel_ramp( in_ch: int, out_ch: int, *, + min_ch: int = 0, num_layers: int = 3, interp: str = "linear", make_layer=None, @@ -339,7 +340,9 @@ def sequential_channel_ramp( **layer_kwargs, ) -> nn.Module: """Interleave layers of gradually ramping channels with nonlinearities.""" - channels = ramp(in_ch, out_ch, num_layers + 1, method=interp).floor().int().tolist() + channels = ramp(in_ch, out_ch, num_layers + 1, method=interp).floor().int() + channels[1:-1] = channels[1:-1].clip(min=min_ch) + channels = channels.tolist() layers = [ module for ch_in, ch_out in zip(channels[:-1], channels[1:]) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index a3021d40..63c399de 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -253,6 +253,7 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): f"y{k}": sequential_channel_ramp( sum(self.groups[:k]), self.groups[k] * 2, + min_ch=N, num_layers=3, make_layer=nn.Conv2d, make_act=lambda: nn.ReLU(inplace=True), @@ -281,6 +282,7 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): # Input: spatial context, channel context, and hyper params. self.groups[k] * 2 + (k > 0) * self.groups[k] * 2 + N * 2, self.groups[k] * 2, + min_ch=N * 2, num_layers=3, make_layer=nn.Conv2d, make_act=lambda: nn.ReLU(inplace=True), From 9b0de2a53c6514ebde1c135a8a65737ce990f81c Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Fri, 29 Sep 2023 18:08:29 -0700 Subject: [PATCH 27/37] feat: elic2022-chandelier --- compressai/models/sensetime.py | 202 +++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index 63c399de..136a809f 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -27,6 +27,8 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import types + import torch import torch.nn as nn @@ -331,6 +333,206 @@ def from_state_dict(cls, state_dict): return net +@register_model("elic2022-chandelier") +class Elic2022Chandelier(SimpleVAECompressionModel): + """ELIC 2022; simplified context model using only first and most recent groups. + + Context model from [He2022], with simplifications and parameters + from the [Chandelier2023] implementation. + Based on modified attention model architecture from [Cheng2020]. + + .. note:: + + This implementation contains some differences compared to the + original [He2022] paper. For instance, the implemented context + model only uses the first and the most recently decoded channel + groups to predict the current channel group. In contrast, the + original paper uses all previously decoded channel groups. + Also, the last layer of h_s is now a conv rather than a deconv. + + [Chandelier2023]: `"ELiC-ReImplemetation" + `_, by + Vincent Chandelier, 2023. + + [He2022]: `"ELIC: Efficient Learned Image Compression with + Unevenly Grouped Space-Channel Contextual Adaptive Coding" + `_, by Dailan He, Ziming Yang, + Weikun Peng, Rui Ma, Hongwei Qin, and Yan Wang, CVPR 2022. 
+ + [Cheng2020]: `"Learned Image Compression with Discretized Gaussian + Mixture Likelihoods and Attention Modules" + `_, by Zhengxue Cheng, Heming Sun, + Masaru Takeuchi, and Jiro Katto, CVPR 2020. + + Args: + N (int): Number of main network channels + M (int): Number of latent space channels + groups (list[int]): Number of channels in each channel group + """ + + def __init__(self, N=192, M=320, groups=None, **kwargs): + super().__init__(**kwargs) + + if groups is None: + groups = [16, 16, 32, 64, M - 128] + + self.groups = list(groups) + assert sum(self.groups) == M + + self.g_a = nn.Sequential( + conv(3, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + conv(N, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + AttentionBlock(N), + conv(N, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + conv(N, M, kernel_size=5, stride=2), + AttentionBlock(M), + ) + + self.g_s = nn.Sequential( + AttentionBlock(M), + deconv(M, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + deconv(N, N, kernel_size=5, stride=2), + AttentionBlock(N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + deconv(N, N, kernel_size=5, stride=2), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + ResidualBottleneckBlock(N, N), + deconv(N, 3, kernel_size=5, stride=2), + ) + + h_a = nn.Sequential( + conv(M, N, kernel_size=3, stride=1), + nn.ReLU(inplace=True), + conv(N, N, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + conv(N, N, kernel_size=5, stride=2), + ) + + h_s = nn.Sequential( + deconv(N, N, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + deconv(N, N * 3 // 2, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + conv(N * 3 // 2, M * 2, kernel_size=3, stride=1), + ) + + # In [He2022], this is labeled "g_ch^(k)". + channel_context = { + f"y{k}": nn.Sequential( + conv( + # Input: first group, and most recently decoded group. + self.groups[0] + (k > 1) * self.groups[k - 1], + 224, + kernel_size=5, + stride=1, + ), + nn.ReLU(inplace=True), + conv(224, 128, kernel_size=5, stride=1), + nn.ReLU(inplace=True), + conv(128, self.groups[k] * 2, kernel_size=5, stride=1), + ) + for k in range(1, len(self.groups)) + } + + # In [He2022], this is labeled "g_sp^(k)". + spatial_context = [ + CheckerboardMaskedConv2d( + self.groups[k], + self.groups[k] * 2, + kernel_size=5, + stride=1, + padding=2, + ) + for k in range(len(self.groups)) + ] + + # In [He2022], this is labeled "Param Aggregation". + param_aggregation = [ + nn.Sequential( + conv1x1( + # Input: spatial context, channel context, and hyper params. + self.groups[k] * 2 + (k > 0) * self.groups[k] * 2 + M * 2, + M * 2, + ), + nn.ReLU(inplace=True), + conv1x1(M * 2, 512), + nn.ReLU(inplace=True), + conv1x1(512, self.groups[k] * 2), + ) + for k in range(len(self.groups)) + ] + + # In [He2022], this is labeled the space-channel context model (SCCTX). + # The side params and channel context params are computed externally. 
+ scctx_latent_codec = { + f"y{k}": CheckerboardLatentCodec( + latent_codec={ + "y": GaussianConditionalLatentCodec(quantizer="ste"), + }, + context_prediction=spatial_context[k], + entropy_parameters=param_aggregation[k], + ) + for k in range(len(self.groups)) + } + + # [He2022] uses a "hyperprior" architecture, which reconstructs y using z. + self.latent_codec = HyperpriorLatentCodec( + latent_codec={ + # Channel groups with space-channel context model (SCCTX): + "y": ChannelGroupsLatentCodec( + groups=self.groups, + channel_context=channel_context, + latent_codec=scctx_latent_codec, + ), + # Side information branch containing z: + "hyper": HyperLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + ), + }, + ) + + self._monkey_patch() + + def _monkey_patch(self): + """Monkey-patch to use only first group and most recent group.""" + + def merge_y(self: ChannelGroupsLatentCodec, *args): + if len(args) == 0: + return Tensor() + if len(args) == 1: + return args[0] + if len(args) < len(self.groups): + return torch.cat([args[0], args[-1]], dim=1) + return torch.cat(args, dim=1) + + chan_groups_latent_codec = self.latent_codec["y"] + obj = chan_groups_latent_codec + obj.merge_y = types.MethodType(merge_y, obj) + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.weight"].size(0) + net = cls(N) + net.load_state_dict(state_dict) + return net + + class ResidualBottleneckBlock(nn.Module): """Residual bottleneck block. From c07ff21fd9ff6338a8503a589342fd782906a55a Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 17:38:06 -0700 Subject: [PATCH 28/37] refactor: ResidualBottleneckBlock exact same output as Chandelier The relu initialization order seems to affect PyTorch output by a very small amount. This is just to make sure that exact same operations are made at the CUDA level. --- compressai/models/sensetime.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index 136a809f..03f5f1b6 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -551,26 +551,22 @@ class ResidualBottleneckBlock(nn.Module): def __init__(self, in_ch: int, out_ch: int): super().__init__() - mid_ch = min(in_ch, out_ch) // 2 self.conv1 = conv1x1(in_ch, mid_ch) + self.relu1 = nn.ReLU(inplace=True) self.conv2 = conv3x3(mid_ch, mid_ch) + self.relu2 = nn.ReLU(inplace=True) self.conv3 = conv1x1(mid_ch, out_ch) - self.relu = nn.ReLU(inplace=True) - - if in_ch != out_ch: - self.skip = conv1x1(in_ch, out_ch) - else: - self.skip = None + self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else nn.Identity() def forward(self, x: Tensor) -> Tensor: - identity = x if self.skip is None else self.skip(x) + identity = self.skip(x) out = x out = self.conv1(out) - out = self.relu(out) + out = self.relu1(out) out = self.conv2(out) - out = self.relu(out) + out = self.relu2(out) out = self.conv3(out) return out + identity From 9a19e2cd7eeb07b8d172553a94646d6f21340f20 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 17:59:15 -0700 Subject: [PATCH 29/37] fix: Elic2022Chandelier: use means, scales; not scales, means For compatibility with Chandelier's pretrained models. 
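In other words (sketch):

```python
from compressai.latent_codecs import GaussianConditionalLatentCodec

# Chandelier's pretrained parameter-aggregation networks emit (means, scales),
# so the codec must unpack in that order; splitting as ("scales", "means")
# would silently swap the two halves and break decoding.
lc = GaussianConditionalLatentCodec(quantizer="ste", chunks=("means", "scales"))
```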
--- compressai/models/sensetime.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index 03f5f1b6..3d6e3aa8 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -482,7 +482,9 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): scctx_latent_codec = { f"y{k}": CheckerboardLatentCodec( latent_codec={ - "y": GaussianConditionalLatentCodec(quantizer="ste"), + "y": GaussianConditionalLatentCodec( + quantizer="ste", chunks=("means", "scales") + ), }, context_prediction=spatial_context[k], entropy_parameters=param_aggregation[k], From 999aaabe5eafafe9c94a828eca3ce806893a3e29 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 17:58:44 -0700 Subject: [PATCH 30/37] feat: elic: use z_hat ste quantizer --- compressai/models/sensetime.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index 3d6e3aa8..ef757f96 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -152,7 +152,10 @@ def __init__(self, N=192, **kwargs): ), ), "hyper": HyperLatentCodec( - entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + entropy_bottleneck=EntropyBottleneck(N), + h_a=h_a, + h_s=h_s, + quantizer="ste", ), }, ) @@ -319,7 +322,10 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): ), # Side information branch containing z: "hyper": HyperLatentCodec( - entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + entropy_bottleneck=EntropyBottleneck(N), + h_a=h_a, + h_s=h_s, + quantizer="ste", ), }, ) @@ -503,7 +509,10 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): ), # Side information branch containing z: "hyper": HyperLatentCodec( - entropy_bottleneck=EntropyBottleneck(N), h_a=h_a, h_s=h_s + entropy_bottleneck=EntropyBottleneck(N), + h_a=h_a, + h_s=h_s, + quantizer="ste", ), }, ) From 94154398c67b60d50aa6ee8854d666eb38f38e9b Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 1 Oct 2023 22:14:13 -0700 Subject: [PATCH 31/37] docs: add ELIC model references --- compressai/latent_codecs/channel_groups.py | 2 +- compressai/latent_codecs/checkerboard.py | 3 +++ compressai/models/sensetime.py | 1 + docs/source/models.rst | 11 ++++++++--- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index 52817a61..ee4a4b20 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -51,7 +51,7 @@ class ChannelGroupsLatentCodec(LatentCodec): Context model from [Minnen2020] and [He2022]. Also known as a "channel-conditional" (CC) entropy model. - See :py:class:`~compressai.models.sensetime.Cheng2020AnchorElic` + See :py:class:`~compressai.models.sensetime.Elic2022Official` for example usage. [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 28e93e10..f3d3c562 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -53,6 +53,9 @@ class CheckerboardLatentCodec(LatentCodec): Checkerboard context model introduced in [He2021]. + See :py:class:`~compressai.models.sensetime.Cheng2020AnchorCheckerboard` + for example usage. + - `forward_method="one_pass"` is fastest, but does not use quantization based on the intermediate means. 
- `forward_method="two_pass"` is slightly slower, but accurately diff --git a/compressai/models/sensetime.py b/compressai/models/sensetime.py index ef757f96..145bfe1c 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -61,6 +61,7 @@ __all__ = [ "Cheng2020AnchorCheckerboard", "Elic2022Official", + "Elic2022Chandelier", ] diff --git a/docs/source/models.rst b/docs/source/models.rst index 6dc0df24..cd64f1b4 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -50,9 +50,14 @@ Cheng2020AnchorCheckerboard .. autoclass:: Cheng2020AnchorCheckerboard -Cheng2020AnchorElic ---------------------------- -.. autoclass:: Cheng2020AnchorElic +Elic2022Official +---------------- +.. autoclass:: Elic2022Official + + +Elic2022Chandelier +------------------ +.. autoclass:: Elic2022Chandelier .. currentmodule:: compressai.models.video From eddb1bce0750676389389f364eea4711ed2cec1a Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Mon, 2 Oct 2023 03:37:16 -0700 Subject: [PATCH 32/37] perf: avoid computation using 'meta' device tensor https://pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html --- compressai/latent_codecs/checkerboard.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index f3d3c562..37f7b55c 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -253,7 +253,8 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A @torch.no_grad() def _y_ctx_zero(self, y): """Create a zero tensor with correct shape for y_ctx.""" - return self._mask(self.context_prediction(y).detach(), "all") + y_ctx_meta = self.context_prediction(y.to("meta")) + return y.new_zeros(y_ctx_meta.shape) def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: n, c, h, w = y.shape From 2db3544d1f81e63d75c723f2126c153236ee250b Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Mon, 2 Oct 2023 04:10:15 -0700 Subject: [PATCH 33/37] style: comments, annotations --- compressai/latent_codecs/checkerboard.py | 63 ++++++++++++++++-------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 37f7b55c..4f04354b 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -56,13 +56,14 @@ class CheckerboardLatentCodec(LatentCodec): See :py:class:`~compressai.models.sensetime.Cheng2020AnchorCheckerboard` for example usage. - - `forward_method="one_pass"` is fastest, but does not use + - `forward_method="onepass"` is fastest, but does not use quantization based on the intermediate means. - - `forward_method="two_pass"` is slightly slower, but accurately - quantizes based on the intermediate means. + Uses noise to model quantization. + - `forward_method="twopass"` is slightly slower, but accurately + quantizes via STE based on the intermediate means. Uses the same operations as [Chandelier2023]. - - `forward_method="two_pass_faster"` uses slightly fewer - redundant operations, but may not work in all cases. + - `forward_method="twopass_faster"` uses slightly fewer + redundant operations. [He2021]: `"Checkerboard Context Model for Efficient Learned Image Compression" `_, by Dailan He, @@ -76,8 +77,6 @@ class CheckerboardLatentCodec(LatentCodec): is a pointwise function, e.g., a composition of 1x1 convs and pointwise nonlinearities. - .. 
note:: This implementation uses uniform noise for training quantization. - .. code-block:: none 0. Input: @@ -153,7 +152,13 @@ def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: raise ValueError(f"Unknown forward method: {self.forward_method}") def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: - """Fast estimation with single pass of the context model.""" + """Fast estimation with single pass of the entropy parameters network. + + It is faster than the twopass method (only one pass required!), + but also less accurate. + + This method uses uniform noise to roughly model quantization. + """ y_hat = self.quantize(y) y_ctx = self._keep_only(self.context_prediction(y_hat), "non_anchor") ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) @@ -166,8 +171,21 @@ def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: } def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: - """Do context prediction on STE-quantized y_hat instead.""" + """Runs the entropy parameters network in two passes. + + The first pass gets ``y_hat`` and ``means_hat`` for the anchors. + This ``y_hat`` is used as context to predict the non-anchors. + The second pass gets ``y_hat`` for the non-anchors. + The two ``y_hat`` tensors are then combined. The resulting + ``y_hat`` models the effects of quantization more realistically. + + To compute ``y_hat_anchors``, we need the predicted ``means_hat``: + ``y_hat = quantize_ste(y - means_hat) + means_hat``. + Thus, two passes of ``entropy_parameters`` are necessary. + + """ B, C, H, W = y.shape + ctx_params = y.new_zeros((B, C * 2, H, W)) y_hat_anchors = self._forward_twopass_step( @@ -186,7 +204,6 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: step="non_anchor", ) - # NOTE: We could also use y_hat = y_out["y_hat"] if it uses the same quantizer. y_hat = y_hat_anchors + y_hat_non_anchors y_out = self.latent_codec["y"](y, ctx_params) @@ -200,9 +217,9 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: def _forward_twopass_step( self, y: Tensor, side_params: Tensor, y_ctx: Tensor, params: Tensor, step: str ) -> Dict[str, Any]: - # NOTE: The _i variables only contain the current step's pixels. + # NOTE: The _i variables contain only the current step's pixels. + assert step in ("anchor", "non_anchor") - # Estimate parameters. params_i = self.entropy_parameters(self.merge(y_ctx, side_params)) # Save params for current step. This is later used for entropy estimation. @@ -213,35 +230,41 @@ def _forward_twopass_step( params_i = self.latent_codec["y"].entropy_parameters(params_i) # Keep only elements needed for current step. + # It's not necessary to mask the rest out just yet, but it doesn't hurt. params_i = self._keep_only(params_i, step) - - # NOTE: It's not necessary to mask out the non-step pixels, but it doesn't hurt. y_i = self._keep_only(y.clone(), step) - # Determine y_hat for current step. + # Determine y_hat for current step, and mask out the other pixels. _, means_i = self.latent_codec["y"]._chunk(params_i) y_hat_i = self._keep_only(quantize_ste(y_i - means_i) + means_i, step) return y_hat_i def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: - """Do context prediction on STE-quantized y_hat instead.""" + """Runs the entropy parameters network in two passes. + + This version was written based on the paper description. 
+ It is a tiny bit faster than the twopass method since + it avoids a few redundant operations. The "probably unnecessary" + operations can likely be removed as well. + The speedup is very small, however. + """ y_ctx = self._y_ctx_zero(y) ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) - ctx_params = self._keep_only(ctx_params, "anchor") # Probably not needed. + ctx_params = self._keep_only(ctx_params, "anchor") # Probably unnecessary. _, means_hat = self.latent_codec["y"]._chunk(ctx_params) y_hat_anchors = quantize_ste(y - means_hat) + means_hat y_hat_anchors = self._keep_only(y_hat_anchors, "anchor") y_ctx = self.context_prediction(y_hat_anchors) - y_ctx = self._keep_only(y_ctx, "non_anchor") # Probably not needed. + y_ctx = self._keep_only(y_ctx, "non_anchor") # Probably unnecessary. ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) y_out = self.latent_codec["y"](y, ctx_params) # Reuse quantized y_hat that was used for non-anchor context prediction. y_hat = y_out["y_hat"] - self._copy(y_hat, y_hat_anchors, "anchor") # Probably not needed. + self._copy(y_hat, y_hat_anchors, "anchor") # Probably unnecessary. return { "likelihoods": { @@ -251,7 +274,7 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A } @torch.no_grad() - def _y_ctx_zero(self, y): + def _y_ctx_zero(self, y: Tensor) -> Tensor: """Create a zero tensor with correct shape for y_ctx.""" y_ctx_meta = self.context_prediction(y.to("meta")) return y.new_zeros(y_ctx_meta.shape) From a8184465dd4e34f19b3b8388f740a93e7e7afe04 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Mon, 2 Oct 2023 04:10:59 -0700 Subject: [PATCH 34/37] refactor: clarity --- compressai/latent_codecs/checkerboard.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 4f04354b..70c83594 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -191,17 +191,17 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_hat_anchors = self._forward_twopass_step( y, side_params, - self._y_ctx_zero(y), ctx_params, - step="anchor", + self._y_ctx_zero(y), + "anchor", ) y_hat_non_anchors = self._forward_twopass_step( y, side_params, - self.context_prediction(y_hat_anchors), ctx_params, - step="non_anchor", + self.context_prediction(y_hat_anchors), + "non_anchor", ) y_hat = y_hat_anchors + y_hat_non_anchors @@ -215,7 +215,7 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: } def _forward_twopass_step( - self, y: Tensor, side_params: Tensor, y_ctx: Tensor, params: Tensor, step: str + self, y: Tensor, side_params: Tensor, params: Tensor, y_ctx: Tensor, step: str ) -> Dict[str, Any]: # NOTE: The _i variables contain only the current step's pixels. assert step in ("anchor", "non_anchor") @@ -225,9 +225,9 @@ def _forward_twopass_step( # Save params for current step. This is later used for entropy estimation. self._copy(params, params_i, step) - # Technically, latent_codec may also contain an "entropy_parameters" method. - # Usually, it is identity, though. - params_i = self.latent_codec["y"].entropy_parameters(params_i) + # Apply latent_codec's "entropy_parameters()", if it exists. Usually identity. 
+ func = getattr(self.latent_codec["y"], "entropy_parameters", lambda x: x) + params_i = func(params_i) # Keep only elements needed for current step. # It's not necessary to mask the rest out just yet, but it doesn't hurt. @@ -251,7 +251,8 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A """ y_ctx = self._y_ctx_zero(y) ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) - ctx_params = self.latent_codec["y"].entropy_parameters(ctx_params) + func = getattr(self.latent_codec["y"], "entropy_parameters", lambda x: x) + ctx_params = func(ctx_params) ctx_params = self._keep_only(ctx_params, "anchor") # Probably unnecessary. _, means_hat = self.latent_codec["y"]._chunk(ctx_params) y_hat_anchors = quantize_ste(y - means_hat) + means_hat @@ -316,6 +317,7 @@ def decompress( assert all(len(x) == n for x in y_strings_) c, h, w = shape + y_i_shape = (h, w // 2) y_hat_ = side_params.new_zeros((2, n, c, h, w // 2)) side_params_ = self.unembed(side_params) @@ -325,7 +327,7 @@ def decompress( y_ctx_i = self._mask(y_ctx_i, "all") ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) y_out = self.latent_codec["y"].decompress( - [y_strings_[i]], shape=(h, w // 2), ctx_params=ctx_params_i + [y_strings_[i]], y_i_shape, ctx_params_i ) y_hat_[i] = y_out["y_hat"] From 800fa72551299e38fe10bf51fd37a25e0ab70364 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Mon, 2 Oct 2023 02:39:44 -0700 Subject: [PATCH 35/37] style: rename ctx_params -> params --- compressai/latent_codecs/checkerboard.py | 40 ++++++++++-------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 70c83594..9837ee35 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -161,8 +161,8 @@ def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: """ y_hat = self.quantize(y) y_ctx = self._keep_only(self.context_prediction(y_hat), "non_anchor") - ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) - y_out = self.latent_codec["y"](y, ctx_params) + params = self.entropy_parameters(self.merge(y_ctx, side_params)) + y_out = self.latent_codec["y"](y, params) return { "likelihoods": { "y": y_out["likelihoods"]["y"], @@ -186,26 +186,18 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: """ B, C, H, W = y.shape - ctx_params = y.new_zeros((B, C * 2, H, W)) + params = y.new_zeros((B, C * 2, H, W)) y_hat_anchors = self._forward_twopass_step( - y, - side_params, - ctx_params, - self._y_ctx_zero(y), - "anchor", + y, side_params, params, self._y_ctx_zero(y), "anchor" ) y_hat_non_anchors = self._forward_twopass_step( - y, - side_params, - ctx_params, - self.context_prediction(y_hat_anchors), - "non_anchor", + y, side_params, params, self.context_prediction(y_hat_anchors), "non_anchor" ) y_hat = y_hat_anchors + y_hat_non_anchors - y_out = self.latent_codec["y"](y, ctx_params) + y_out = self.latent_codec["y"](y, params) return { "likelihoods": { @@ -250,18 +242,18 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A The speedup is very small, however. 
""" y_ctx = self._y_ctx_zero(y) - ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) + params = self.entropy_parameters(self.merge(y_ctx, side_params)) func = getattr(self.latent_codec["y"], "entropy_parameters", lambda x: x) - ctx_params = func(ctx_params) - ctx_params = self._keep_only(ctx_params, "anchor") # Probably unnecessary. - _, means_hat = self.latent_codec["y"]._chunk(ctx_params) + params = func(params) + params = self._keep_only(params, "anchor") # Probably unnecessary. + _, means_hat = self.latent_codec["y"]._chunk(params) y_hat_anchors = quantize_ste(y - means_hat) + means_hat y_hat_anchors = self._keep_only(y_hat_anchors, "anchor") y_ctx = self.context_prediction(y_hat_anchors) y_ctx = self._keep_only(y_ctx, "non_anchor") # Probably unnecessary. - ctx_params = self.entropy_parameters(self.merge(y_ctx, side_params)) - y_out = self.latent_codec["y"](y, ctx_params) + params = self.entropy_parameters(self.merge(y_ctx, side_params)) + y_out = self.latent_codec["y"](y, params) # Reuse quantized y_hat that was used for non-anchor context prediction. y_hat = y_out["y_hat"] @@ -291,8 +283,8 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] if i == 0: y_ctx_i = self._mask(y_ctx_i, "all") - ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) - y_out = self.latent_codec["y"].compress(y_[i], ctx_params_i) + params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) + y_out = self.latent_codec["y"].compress(y_[i], params_i) y_hat_[i] = y_out["y_hat"] [y_strings_[i]] = y_out["strings"] @@ -325,9 +317,9 @@ def decompress( y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i] if i == 0: y_ctx_i = self._mask(y_ctx_i, "all") - ctx_params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) + params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i])) y_out = self.latent_codec["y"].decompress( - [y_strings_[i]], y_i_shape, ctx_params_i + [y_strings_[i]], y_i_shape, params_i ) y_hat_[i] = y_out["y_hat"] From 2489952189980827abea01e8c20018dd6c0905b9 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 8 Oct 2023 16:45:00 -0700 Subject: [PATCH 36/37] perf: allow @torch.compile by avoiding in-place operation for MaskedConv `@torch.compile` can speed up model training by 5% - 200%. Simply use: ```python model = torch.compile(model) ``` This commit resolves the error that comes up when compiling MaskedConv2d: ```none RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. ``` --- compressai/layers/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index 5bfb855b..73fcbce1 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -79,7 +79,7 @@ def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any): def forward(self, x: Tensor) -> Tensor: # TODO(begaintj): weight assigment is not supported by torchscript - self.weight.data *= self.mask + self.weight.data = self.weight.data * self.mask return super().forward(x) From a494099357f1e3a828357e60862174fe7632b84a Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Sun, 8 Oct 2023 16:45:28 -0700 Subject: [PATCH 37/37] perf: allow @torch.compile by avoiding in-place operation via clone() `@torch.compile` can speed up model training by 5% - 200%. 
To enable compilation, simply use:

```python
model = torch.compile(model)
```

This commit resolves the error that comes up when compiling _mask:

```none
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```
---
 compressai/latent_codecs/checkerboard.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py
index 9837ee35..483ea49f 100644
--- a/compressai/latent_codecs/checkerboard.py
+++ b/compressai/latent_codecs/checkerboard.py
@@ -224,7 +224,7 @@ def _forward_twopass_step(
         # Keep only elements needed for current step.
         # It's not necessary to mask the rest out just yet, but it doesn't hurt.
         params_i = self._keep_only(params_i, step)
-        y_i = self._keep_only(y.clone(), step)
+        y_i = self._keep_only(y, step)
 
         # Determine y_hat for current step, and mask out the other pixels.
         _, means_i = self.latent_codec["y"]._chunk(params_i)
@@ -387,12 +387,17 @@ def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
             dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
             dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
 
-    def _keep_only(self, y: Tensor, step: str) -> Tensor:
+    def _keep_only(self, y: Tensor, step: str, inplace: bool = False) -> Tensor:
         """Keep only pixels in the current step, and zero out the rest."""
-        parity = self.non_anchor_parity if step == "anchor" else self.anchor_parity
-        return self._mask(y, parity)
+        return self._mask(
+            y,
+            parity=self.non_anchor_parity if step == "anchor" else self.anchor_parity,
+            inplace=inplace,
+        )
 
-    def _mask(self, y: Tensor, parity: str) -> Tensor:
+    def _mask(self, y: Tensor, parity: str, inplace: bool = False) -> Tensor:
+        if not inplace:
+            y = y.clone()
         if parity == "even":
             y[..., 0::2, 0::2] = 0
             y[..., 1::2, 1::2] = 0