diff --git a/albumentations/core/composition.py b/albumentations/core/composition.py index 4092069d9..ae2e37fbf 100644 --- a/albumentations/core/composition.py +++ b/albumentations/core/composition.py @@ -3,6 +3,7 @@ import random +import cv2 import numpy as np from albumentations.augmentations.keypoints_utils import KeypointsProcessor @@ -116,6 +117,26 @@ def set_deterministic(self, flag, save_key="replay"): for t in self.transforms: t.set_deterministic(flag, save_key) + def set_mask_interpolation(self, mask_interpolation): + if mask_interpolation not in { + cv2.INTER_NEAREST, + cv2.INTER_NEAREST_EXACT, + cv2.INTER_LINEAR, + cv2.INTER_LINEAR_EXACT, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + cv2.INTER_MAX, + }: + raise ValueError( + f"Value {mask_interpolation} is not supported. " + f"Choose one of the following methods: cv2.INTER_NEAREST, cv2.INTER_NEAREST_EXACT, cv2.INTER_LINEAR, " + f"cv2.INTER_LINEAR_EXACT, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4, cv2.INTER_MAX" + ) + for t in self.transforms: + if isinstance(t, BaseCompose) or (isinstance(t, DualTransform) and t.mask_interpolation is None): + t.set_mask_interpolation(mask_interpolation) + class Compose(BaseCompose): """Compose transforms and handle all transformations regarding bounding boxes @@ -128,8 +149,16 @@ class Compose(BaseCompose): p (float): probability of applying all list of transforms. Default: 1.0. """ - def __init__(self, transforms, bbox_params=None, keypoint_params=None, additional_targets=None, p=1.0): - super(Compose, self).__init__([t for t in transforms if t is not None], p) + def __init__( + self, + transforms, + bbox_params=None, + keypoint_params=None, + additional_targets=None, + mask_interpolation=cv2.INTER_NEAREST, + p=1.0, + ): + super().__init__([t for t in transforms if t is not None], p) self.processors = {} if bbox_params: @@ -159,6 +188,7 @@ def __init__(self, transforms, bbox_params=None, keypoint_params=None, additiona proc.ensure_transforms_valid(self.transforms) self.add_targets(additional_targets) + self.set_mask_interpolation(mask_interpolation) def __call__(self, *args, force_apply=False, **data): if args: diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index 6bc6cb241..63e658a29 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -202,6 +202,10 @@ def get_dict_with_id(self): class DualTransform(BasicTransform): """Transform for segmentation task.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._mask_interpolation = None + @property def targets(self): return { @@ -212,6 +216,18 @@ def targets(self): "keypoints": self.apply_to_keypoints, } + @property + def mask_interpolation(self): + return self._mask_interpolation + + @mask_interpolation.setter + def mask_interpolation(self, mask_interpolation): + self._mask_interpolation = mask_interpolation + + def set_mask_interpolation(self, mask_interpolation): + self._mask_interpolation = mask_interpolation + return self + def apply_to_bbox(self, bbox, **params): raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__) @@ -225,7 +241,8 @@ def apply_to_keypoints(self, keypoints, **params): return [self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:]) for keypoint in keypoints] def apply_to_mask(self, img, **params): - return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()}) + mask_interpolation = self.mask_interpolation if self.mask_interpolation is not None else cv2.INTER_NEAREST + return self.apply(img, **{k: mask_interpolation if k == "interpolation" else v for k, v in params.items()}) def apply_to_masks(self, masks, **params): return [self.apply_to_mask(mask, **params) for mask in masks]