Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to set a custom mask interpolation method #945

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import random

import cv2
import numpy as np

from albumentations.augmentations.keypoints_utils import KeypointsProcessor
Expand Down Expand Up @@ -116,6 +117,26 @@ def set_deterministic(self, flag, save_key="replay"):
for t in self.transforms:
t.set_deterministic(flag, save_key)

def set_mask_interpolation(self, mask_interpolation):
if mask_interpolation not in {
cv2.INTER_NEAREST,
cv2.INTER_NEAREST_EXACT,
cv2.INTER_LINEAR,
cv2.INTER_LINEAR_EXACT,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
cv2.INTER_MAX,
}:
raise ValueError(
f"Value {mask_interpolation} is not supported. "
f"Choose one of the following methods: cv2.INTER_NEAREST, cv2.INTER_NEAREST_EXACT, cv2.INTER_LINEAR, "
f"cv2.INTER_LINEAR_EXACT, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4, cv2.INTER_MAX"
)
Comment on lines +121 to +135
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to create an Enum as class field for this flag?
With Enum error message would be much more clearly.

for t in self.transforms:
if isinstance(t, BaseCompose) or (isinstance(t, DualTransform) and t.mask_interpolation is None):
t.set_mask_interpolation(mask_interpolation)
Comment on lines +136 to +138
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strange condition, I am not sure that it will be work correctly, because mask_interpolation has no default value.



class Compose(BaseCompose):
"""Compose transforms and handle all transformations regarding bounding boxes
Expand All @@ -128,8 +149,16 @@ class Compose(BaseCompose):
p (float): probability of applying all list of transforms. Default: 1.0.
"""

def __init__(self, transforms, bbox_params=None, keypoint_params=None, additional_targets=None, p=1.0):
super(Compose, self).__init__([t for t in transforms if t is not None], p)
def __init__(
self,
transforms,
bbox_params=None,
keypoint_params=None,
additional_targets=None,
mask_interpolation=cv2.INTER_NEAREST,
p=1.0,
):
super().__init__([t for t in transforms if t is not None], p)

self.processors = {}
if bbox_params:
Expand Down Expand Up @@ -159,6 +188,7 @@ def __init__(self, transforms, bbox_params=None, keypoint_params=None, additiona
proc.ensure_transforms_valid(self.transforms)

self.add_targets(additional_targets)
self.set_mask_interpolation(mask_interpolation)

def __call__(self, *args, force_apply=False, **data):
if args:
Expand Down
19 changes: 18 additions & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ def get_dict_with_id(self):
class DualTransform(BasicTransform):
"""Transform for segmentation task."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._mask_interpolation = None

@property
def targets(self):
return {
Expand All @@ -212,6 +216,18 @@ def targets(self):
"keypoints": self.apply_to_keypoints,
}

@property
def mask_interpolation(self):
return self._mask_interpolation

@mask_interpolation.setter
def mask_interpolation(self, mask_interpolation):
self._mask_interpolation = mask_interpolation

def set_mask_interpolation(self, mask_interpolation):
self._mask_interpolation = mask_interpolation
return self

def apply_to_bbox(self, bbox, **params):
raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__)

Expand All @@ -225,7 +241,8 @@ def apply_to_keypoints(self, keypoints, **params):
return [self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:]) for keypoint in keypoints]

def apply_to_mask(self, img, **params):
return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()})
mask_interpolation = self.mask_interpolation if self.mask_interpolation is not None else cv2.INTER_NEAREST
return self.apply(img, **{k: mask_interpolation if k == "interpolation" else v for k, v in params.items()})

def apply_to_masks(self, masks, **params):
return [self.apply_to_mask(mask, **params) for mask in masks]
Expand Down