From 243bec9fa0e7e59c62476fa7c757e0451e947b18 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Mon, 16 Oct 2023 19:37:58 -0700 Subject: [PATCH] feat: SpectralConv2d, SpectralConvTranspose2d MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduced in ["Efficient Nonlinear Transforms for Lossy Image Compression"][Balle2018efficient] by Johannes Ballé, PCS 2018. Reparameterizes the weights to be derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation behind representing the weights in the frequency domain is that optimizer updates/steps may now affect all frequencies to an equal amount. This improves the gradient conditioning, thus leading to faster convergence and increased stability at larger learning rates. For comparison, see the TensorFlow Compression implementations of [`SignalConv2D`] and [`RDFTParameter`]. They seem to use `SignalConv2d` in most of their provided architectures: https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code Furthermore, since this is a simple invertible transformation on the weights, it is trivial to convert any existing pretrained weights into this form via: ```python weight_transformed = self._to_transform_domain(weight) ``` To override `self.weight` as a property, I'm unregistering the module using `del self._parameters["weight"]` as shown in https://github.com/pytorch/pytorch/issues/46886, and also [using the fact][property-descriptor-so] that `@property` [returns a descriptor object][property-descriptor-docs] so that `self.weight` "falls back" to the property. ```python def __init__(self, ...): self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight)) del self._parameters["weight"] # Unregister weight, and fallback to property. @property def weight(self) -> Tensor: return self._from_transform_domain(self.weight_transformed) ``` [Balle2018efficient]: https://arxiv.org/abs/1802.00847 [`SignalConv2D`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61 [`RDFTParameter`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71 [property-descriptor-docs]: https://docs.python.org/3/howto/descriptor.html#properties [property-descriptor-so]: https://stackoverflow.com/a/17330273/365102 [`eval` mode]: https://stackoverflow.com/a/51433411/365102 --- compressai/layers/layers.py | 63 ++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py index 540771ef..2d3754e3 100644 --- a/compressai/layers/layers.py +++ b/compressai/layers/layers.py @@ -27,7 +27,7 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Any +from typing import Any, Tuple import torch import torch.nn as nn @@ -43,12 +43,73 @@ "ResidualBlock", "ResidualBlockUpsample", "ResidualBlockWithStride", + "SpectralConv2d", + "SpectralConvTranspose2d", "conv3x3", "subpel_conv3x3", "QReLU", ] +class _SpectralConvNdMixin: + def __init__(self, dim: Tuple[int, ...]): + self.dim = dim + self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight)) + del self._parameters["weight"] # Unregister weight, and fallback to property. + + @property + def weight(self) -> Tensor: + return self._from_transform_domain(self.weight_transformed) + + def _to_transform_domain(self, x: Tensor) -> Tensor: + return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho") + + def _from_transform_domain(self, x: Tensor) -> Tensor: + return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho") + + +class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin): + r"""Spectral 2D convolution. + + Introduced in [Balle2018efficient]. + Reparameterizes the weights to be derived from weights stored in the + frequency domain. + In the original paper, this is referred to as "spectral Adam" or + "Sadam" due to its effect on the Adam optimizer update rule. + The motivation behind representing the weights in the frequency + domain is that optimizer updates/steps may now affect all + frequencies to an equal amount. + This improves the gradient conditioning, thus leading to faster + convergence and increased stability at larger learning rates. + + For comparison, see the TensorFlow Compression implementations of + `SignalConv2D + `_ + and + `RDFTParameter + `_. + + [Balle2018efficient]: `"Efficient Nonlinear Transforms for Lossy + Image Compression" `_, + by Johannes Ballé, PCS 2018. + """ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + _SpectralConvNdMixin.__init__(self, dim=(-2, -1)) + + +class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin): + r"""Spectral 2D transposed convolution. + + Transposed version of :class:`SpectralConv2d`. + """ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + _SpectralConvNdMixin.__init__(self, dim=(-2, -1)) + + class MaskedConv2d(nn.Conv2d): r"""Masked 2D convolution implementation, mask future "unseen" pixels. Useful for building auto-regressive network components.