diff --git a/compressai/models/base.py b/compressai/models/base.py index 543aa96c..0f326ba0 100644 --- a/compressai/models/base.py +++ b/compressai/models/base.py @@ -28,6 +28,7 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import math +import warnings from typing import cast @@ -66,6 +67,30 @@ class CompressionModel(nn.Module): EntropyBottleneck or GaussianConditional modules. """ + def __init__(self, entropy_bottleneck_channels=None, init_weights=None): + super().__init__() + + if entropy_bottleneck_channels is not None: + warnings.warn( + "The entropy_bottleneck_channels parameter is deprecated. " + "Create an entropy_bottleneck in your model directly instead:\n\n" + "class YourModel(CompressionModel):\n" + " def __init__(self):\n" + " super().__init__()\n" + " self.entropy_bottleneck = " + "EntropyBottleneck(entropy_bottleneck_channels)\n", + DeprecationWarning, + stacklevel=2, + ) + self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) + + if init_weights is not None: + warnings.warn( + "The init_weights parameter was removed as it was never functional.", + DeprecationWarning, + stacklevel=2, + ) + def load_state_dict(self, state_dict, strict=True): for name, module in self.named_modules(): if not any(x.startswith(name) for x in state_dict.keys()):