diff --git a/compressai/zoo/image_vbr.py b/compressai/zoo/image_vbr.py index d48f418d..d103f839 100644 --- a/compressai/zoo/image_vbr.py +++ b/compressai/zoo/image_vbr.py @@ -60,7 +60,7 @@ # "ms-ssim": f"{root_url}/mbt2018-mean-ms-ssim-vbr-HASH.pth.tar", }, "mbt2018-vbr": { - "mse": f"{root_url}/mbt2018-mse-vbr-53f56fca.pth.tar", + "mse": f"{root_url}/mbt2018-mse-vbr-f12581a1.pth.tar", # "ms-ssim": f"{root_url}/mbt2018-ms-ssim-vbr-HASH.pth.tar", }, } @@ -83,12 +83,12 @@ def _load_model(architecture, metric, pretrained=False, progress=True, **kwargs) url = model_urls[architecture][metric] state_dict = load_state_dict_from_url(url, progress=progress) state_dict = load_pretrained(state_dict) - vr_entbttlnck = False if architecture in ["bmshj2018-hyperprior-vbr", "mbt2018-mean-vbr"]: - vr_entbttlnck = True - model = model_architectures[architecture].from_state_dict( - state_dict, vr_entbttlnck - ) + model = model_architectures[architecture].from_state_dict( + state_dict, vr_entbttlnck=True + ) + else: + model = model_architectures[architecture].from_state_dict(state_dict) return model model = model_architectures[architecture](*cfgs[architecture], **kwargs)