Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random fill_value #597

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
import random
import typing
import warnings
from enum import IntEnum
from types import LambdaType
Expand All @@ -13,7 +14,7 @@
from . import functional as F
from .bbox_utils import denormalize_bbox, normalize_bbox, union_of_bboxes
from ..core.transforms_interface import DualTransform, ImageOnlyTransform, NoOp, to_tuple
from ..core.utils import format_args
from ..core.utils import format_args, get_random_color

__all__ = [
"Blur",
Expand Down Expand Up @@ -1509,7 +1510,8 @@ class CoarseDropout(ImageOnlyTransform):
`min_height` is set to `max_height`. Default: `None`.
min_width (int): Minimum width of the hole. If `None`, `min_height` is
set to `max_width`. Default: `None`.
fill_value (int, float, lisf of int, list of float): value for dropped pixels.
fill_value (int, float, lisf of int, list of float, string): value for dropped pixels.
If fill_value is 'random', random color will be generated.

Targets:
image
Expand Down Expand Up @@ -1570,7 +1572,11 @@ def get_params_dependent_on_targets(self, params):
x2 = x1 + hole_width
holes.append((x1, y1, x2, y2))

return {"holes": holes}
fill_value = self.fill_value
if self.fill_value == "random":
fill_value = get_random_color(F.get_num_channels(img), img.dtype)

return {"holes": holes, "fill_value": fill_value}

@property
def targets_as_params(self):
Expand Down Expand Up @@ -3105,8 +3111,9 @@ def __init__(self, max_objects=1, image_fill_value=0, mask_fill_value=0, always_
Args:
max_objects: Maximum number of labels that can be zeroed out. Can be tuple, in this case it's [min, max]
image_fill_value: Fill value to use when filling image.
Can be 'inpaint' to apply inpaining (works only for 3-chahnel images)
mask_fill_value: Fill value to use when filling mask.
Can be 'inpaint' to apply inpaining (works only for 3-chahnel images).
If image_fill_value is 'random', random color will be generated. Default = 0.
mask_fill_value: Fill value to use when filling mask. Default = 0.

Targets:
image, mask
Expand All @@ -3121,7 +3128,7 @@ def __init__(self, max_objects=1, image_fill_value=0, mask_fill_value=0, always_

@property
def targets_as_params(self):
return ["mask"]
return ["image", "mask"]

def get_params_dependent_on_targets(self, params):
mask = params["mask"]
Expand All @@ -3142,21 +3149,26 @@ def get_params_dependent_on_targets(self, params):
for label_index in labels_index:
dropout_mask |= label_image == label_index

params.update({"dropout_mask": dropout_mask})
image_fill_value = self.image_fill_value
if self.image_fill_value == "random":
img = params["image"]
image_fill_value = get_random_color(F.get_num_channels(img), img.dtype)

params.update({"dropout_mask": dropout_mask, "image_fill_value": image_fill_value})
return params

def apply(self, img, dropout_mask=None, **params):
def apply(self, img, dropout_mask=None, image_fill_value=0, **params):
if dropout_mask is None:
return img

if self.image_fill_value == "inpaint":
if isinstance(image_fill_value, str) and image_fill_value == "inpaint":
dropout_mask = dropout_mask.astype(np.uint8)
_, _, w, h = cv2.boundingRect(dropout_mask)
radius = min(3, max(w, h) // 2)
img = cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS)
else:
img = img.copy()
img[dropout_mask] = self.image_fill_value
img[dropout_mask] = image_fill_value

return img

Expand Down Expand Up @@ -3249,7 +3261,8 @@ class GridDropout(DualTransform):
Clipped between 0 and grid unit height - hole_height. Default: 0.
random_offset (boolean): weather to offset the grid randomly between 0 and grid unit size - hole size
If 'True', entered shift_x, shift_y are ignored and set randomly. Default: `False`.
fill_value (int): value for the dropped pixels. Default = 0
fill_value (int): value for the dropped pixels.
If fill_value is 'random', random color will be generated. Default = 0
mask_fill_value (int): value for the dropped pixels in mask.
If `None`, tranformation is not applied to the mask. Default: `None`.

Expand All @@ -3274,7 +3287,7 @@ def __init__(
shift_x: int = 0,
shift_y: int = 0,
random_offset: bool = False,
fill_value: int = 0,
fill_value: typing.Union[int, str] = 0,
mask_fill_value: int = None,
always_apply: bool = False,
p: float = 0.5,
Expand All @@ -3293,8 +3306,8 @@ def __init__(
if not 0 < self.ratio <= 1:
raise ValueError("ratio must be between 0 and 1.")

def apply(self, image, holes=(), **params):
return F.cutout(image, holes, self.fill_value)
def apply(self, image, holes=(), fill_value=0, **params):
return F.cutout(image, holes, fill_value)

def apply_to_mask(self, image, holes=(), **params):
if self.mask_fill_value is None:
Expand Down Expand Up @@ -3354,7 +3367,11 @@ def get_params_dependent_on_targets(self, params):
y2 = min(y1 + hole_height, height)
holes.append((x1, y1, x2, y2))

return {"holes": holes}
fill_value = self.fill_value
if self.fill_value == "random":
fill_value = get_random_color(F.get_num_channels(img), img.dtype)

return {"holes": holes, "fill_value": fill_value}

@property
def targets_as_params(self):
Expand Down
2 changes: 1 addition & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def targets(self):
def update_params(self, params, **kwargs):
if hasattr(self, "interpolation"):
params["interpolation"] = self.interpolation
if hasattr(self, "fill_value"):
if hasattr(self, "fill_value") and "fill_value" not in params:
params["fill_value"] = self.fill_value
params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]})
return params
Expand Down
10 changes: 10 additions & 0 deletions albumentations/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import
from abc import ABCMeta, abstractmethod
import numpy as np

from ..core.six import string_types, add_metaclass

Expand All @@ -13,6 +14,15 @@ def format_args(args_dict):
return ", ".join(formatted_args)


def get_random_color(img_channels, dtype=np.uint8):
if dtype == np.uint8:
fill_value = np.random.randint(0, 256, img_channels, np.uint8)
else:
fill_value = np.random.uniform(0, 1, size=img_channels).astype(np.float32)

return fill_value


@add_metaclass(ABCMeta)
class Params:
def __init__(self, format, label_fields=None):
Expand Down
19 changes: 19 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,3 +652,22 @@ def test_gauss_noise_incorrect_var_limit_type():
A.GaussNoise(var_limit={"low": 70, "high": 90})
message = "Expected var_limit type to be one of (int, float, tuple, list), got <class 'dict'>"
assert str(exc_info.value) == message


@pytest.mark.parametrize(
["augmentation_cls", "params"],
[
[A.CoarseDropout, {"fill_value": "random"}],
[A.GridDropout, {"fill_value": "random"}],
[A.MaskDropout, {"image_fill_value": "random"}],
],
)
def test_fill_value_random(augmentation_cls, params):
image = np.zeros((100, 100, 3))
mask = np.random.randint(0, 5, image.shape[:2], dtype=np.uint8)
aug = augmentation_cls(always_apply=True, **params)

augmented1 = aug(image=image, mask=mask)["image"]
augmented2 = aug(image=image, mask=mask)["image"]

assert not np.allclose(np.unique(augmented1), np.unique(augmented2))