Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AugMix #607

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ The full documentation is available at [albumentations.readthedocs.io](https://a
## Pixel-level transforms
Pixel-level transforms will change just an input image and will leave any additional targets such as masks, bounding boxes, and keypoints unchanged. The list of pixel-level transforms:

- [AugMix](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.AugMix)
- [Autocontrast](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Autocontrast)
- [Blur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Blur)
- [CLAHE](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CLAHE)
- [ChannelDropout](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ChannelDropout)
Expand Down Expand Up @@ -183,6 +185,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [RandomResizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomResizedCrop) | ✓ | ✓ | ✓ | ✓ |
| [RandomRotate90](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomRotate90) | ✓ | ✓ | ✓ | ✓ |
| [RandomScale](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomScale) | ✓ | ✓ | ✓ | ✓ |
| [RandomShear](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomShear) | ✓ | ✓ | ✓ | ✓ |
| [RandomSizedBBoxSafeCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedBBoxSafeCrop) | ✓ | ✓ | ✓ | |
| [RandomSizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedCrop) | ✓ | ✓ | ✓ | ✓ |
| [Resize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Resize) | ✓ | ✓ | ✓ | ✓ |
Expand Down
102 changes: 96 additions & 6 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,20 @@ def shift_scale_rotate(
return warp_affine_fn(img)


def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rows, cols, **kwargs): # skipcq: PYL-W0613
@preserve_channel_dim
def shear(img, shear_x=0, shear_y=0, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, value=None):
height, width = img.shape[:2]
matrix = np.array([[1, shear_x, 0], [shear_y, 1, 0]], dtype=np.float32)

warp_affine_fn = _maybe_process_in_chunks(
cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
)
return warp_affine_fn(img)


def bbox_affine_transform(bbox, height, width, matrix):
x_min, y_min, x_max, y_max = bbox[:4]
height, width = rows, cols
center = (width / 2, height / 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
matrix[0, 2] += dx * width
matrix[1, 2] += dy * height

x = np.array([x_min, x_max, x_max, x_min])
y = np.array([y_min, y_min, y_max, y_max])
ones = np.ones(shape=(len(x)))
Expand All @@ -257,6 +264,22 @@ def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rows, cols, **kwargs):
return x_min, y_min, x_max, y_max


def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rows, cols, **kwargs): # skipcq: PYL-W0613
height, width = rows, cols
center = (width / 2, height / 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
matrix[0, 2] += dx * width
matrix[1, 2] += dy * height

return bbox_affine_transform(bbox, height, width, matrix)


def bbox_shear(bbox, shear_x, shear_y, rows, cols, **kwargs): # skipcq: PYL-W0613
matrix = np.array([[1, shear_x, 0], [shear_y, 1, 0]], dtype=np.float32)

return bbox_affine_transform(bbox, rows, cols, matrix)


@angle_2pi_range
def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **params):
x, y, a, s, = keypoint[:4]
Expand All @@ -273,6 +296,15 @@ def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **pa
return x, y, angle, scale


def keypoint_shear(keypoint, shear_x, shear_y):
x, y, a, s, = keypoint[:4]
matrix = np.array([[1, shear_x, 0], [shear_y, 1, 0]], dtype=np.float32)

x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()

return x, y, a, s


def crop(img, x_min, y_min, x_max, y_max):
height, width = img.shape[:2]
if x_max <= x_min or y_max <= y_min:
Expand Down Expand Up @@ -610,6 +642,39 @@ def equalize(img, mask=None, mode="cv", by_channels=True):
return result_img


def _autocontrast(img):
h = cv2.calcHist([img], [0], None, [256], (0, 256)).ravel()

for lo in range(256):
if h[lo]:
break
for hi in range(255, -1, -1):
if h[hi]:
break

if hi > lo:
lut = np.zeros(256, dtype=np.uint8)
scale_coef = 255.0 / (hi - lo)
offset = -lo * scale_coef
for ix in range(256):
lut[ix] = int(np.clip(ix * scale_coef + offset, 0, 255))

img = cv2.LUT(img, lut)

return img


@preserve_channel_dim
def autocontrast(img):
if len(img.shape) == 2:
result = _autocontrast(img)
else:
result = np.zeros_like(img)
for ch in range(get_num_channels(img)):
result[..., ch] = _autocontrast(img[..., ch])
return result


@clipped
def _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift):
if r_shift == g_shift == b_shift:
Expand Down Expand Up @@ -2041,3 +2106,28 @@ def glass_blur(img, sigma, max_delta, iterations, dxy, mode):
x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w]

return np.clip(cv2.GaussianBlur(x / coef, sigmaX=sigma, ksize=(0, 0)), 0, 1) * coef


def aug_mix(img, alpha, width, depth, transforms, mean, std, random_state=None):
if random_state is None:
random_state = np.random.RandomState(42)

if img.dtype == np.float32:
img = (img * 255).astype(np.uint8)

ws = np.float32(random_state.dirichlet([alpha] * width))
m = np.float32(random_state.beta(alpha, alpha))

mix = np.zeros_like(img, dtype=np.float32)
for i in range(width):
image_aug = img.copy()

for _ in range(depth):
op = random_state.choice(transforms)
image_aug = op(image=image_aug)["image"]

mix += ws[i] * normalize(image_aug, mean, std)

mix = (1 - m) * normalize(img, mean, std) + m * mix

return mix
173 changes: 173 additions & 0 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
"FancyPCA",
"MaskDropout",
"GridDropout",
"AugMix",
"RandomShear",
"Autocontrast",
]


Expand Down Expand Up @@ -581,6 +584,68 @@ def get_transform_init_args(self):
return {"interpolation": self.interpolation, "scale_limit": to_tuple(self.scale_limit, bias=-1.0)}


class RandomShear(DualTransform):
"""Randomly resize the input. Output image size is different from the input image size.
Args:
shear_x (float, tuple of floats): Shear along x axis. If single float shear_x is picked
from (-shear_x, shear_x) interval. Default: 0.1.
shear_y (float, tuple of floats): Shear along y axis. If single float shear_y is picked
from (-shear_y, shear_y) interval. Default: 0.1.
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR.
border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
Default: cv2.BORDER_REFLECT_101
value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""

def __init__(
self,
shear_x=0.1,
shear_y=0.1,
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.cv2.BORDER_REFLECT_101,
value=None,
always_apply=False,
p=0.5,
):
super().__init__(always_apply, p)
self.shear_x = to_tuple(shear_x)
self.shear_y = to_tuple(shear_y)
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value

def get_params(self):
return {
"shear_x": random.uniform(self.shear_x[0], self.shear_x[1]),
"shear_y": random.uniform(self.shear_y[0], self.shear_y[1]),
}

def apply(self, img, shear_x=0, shear_y=0, **params):
return F.shear(
img, shear_x, shear_y, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, value=self.value
)

def apply_to_bbox(self, bbox, shear_x=0, shear_y=0, **params):
return F.bbox_shear(bbox, shear_x, shear_y, **params)

def apply_to_keypoint(self, keypoint, shear_x=0, shear_y=0, **params):
return F.keypoint_shear(keypoint, shear_x, shear_y)

def get_transform_init_args_names(self):
return ("shear_x", "shear_y", "interpolation", "border_mode", "value")


class ShiftScaleRotate(DualTransform):
"""Randomly apply affine transforms: translate, scale and rotate the input.
Expand Down Expand Up @@ -2287,6 +2352,30 @@ def get_transform_init_args_names(self):
return ("mode", "by_channels")


class Autocontrast(ImageOnlyTransform):
"""Perform automatic contrast enhancement.
Args:
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
Image types:
uint8
"""

def __init__(self, always_apply=False, p=0.5):
super().__init__(always_apply, p)

def apply(self, image, **params):
return F.autocontrast(image)

def get_transform_init_args_names(self):
return ()


class RGBShift(ImageOnlyTransform):
"""Randomly shift values for each channel of the input RGB image.
Expand Down Expand Up @@ -3378,3 +3467,87 @@ def get_transform_init_args_names(self):
"mask_fill_value",
"random_offset",
)


class AugMix(ImageOnlyTransform):
"""Augmentation from "AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty"
Please note that this augmentation performs normalization internally, so resulting image will be float type.
Args:
alpha (float): Probability coefficient for Beta and Dirichlet distributions. Default: (1.0).
width (int): Width of augmentation chain. Default: 3.
depth (int, tuple of ints): Depth of augmentation chain. If single int will be used provided number.
If tuple of depth will be generated in range `[depth[0], depth[1])`. Default: (3).
transforms (list of albumentation transforms): List of transforms from which augmentation will be sampled
on each step of AugMix procedure.
mean (float, list of floats, tuple of float): mean values for normalization. Default: (0.485, 0.456, 0.406).
std (float, list of floats, tuple of floats): std values for normalization. Default: (0.229, 0.224, 0.225).
Targets:
image
Image types:
uint8, float32 3-channel images only
Credit:
https://arxiv.org/pdf/1912.02781.pdf
https://github.com/google-research/augmix
"""

def __init__(
self,
alpha=1.0,
width=3,
depth=3,
transforms=None,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
always_apply=False,
p=0.5,
):
super().__init__(always_apply=always_apply, p=p)
self.alpha = alpha
self.width = width
self.depth = to_tuple(depth, low=1)
self.transforms = transforms

if self.transforms is None:
self.transforms = [
Autocontrast(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All transforms need to have p=1, such that they are always executed when selected by Augmix.

Maybe add a comment and change default values ?

Posterize(num_bits=(3, 4)),
ShiftScaleRotate(shift_limit=0, scale_limit=0, rotate_limit=5, border_mode=cv2.BORDER_CONSTANT),
Solarize(threshold=77),
RandomShear(shear_x=0.09, shear_y=0, border_mode=cv2.BORDER_CONSTANT),
RandomShear(shear_x=0, shear_y=0.09, border_mode=cv2.BORDER_CONSTANT),
ShiftScaleRotate(shift_limit=0.09, scale_limit=0, rotate_limit=0, border_mode=cv2.BORDER_CONSTANT),
]

self.mean = mean
self.std = std

def apply(self, img, depth=3, random_state=None, **params):
return F.aug_mix(
img,
self.alpha,
self.width,
depth,
self.transforms,
self.mean,
self.std,
random_state=np.random.RandomState(random_state),
)

def get_params(self):
return {"depth": random.randint(self.depth[0], self.depth[1]), "random_state": random.randint(0, 10000)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assigning a single int to depth here makes all augmentations of width have the same depth (this is in divergence to the reference implementation


def _to_dict(self):
state = {
"__class_fullname__": self.get_class_fullname(),
"alpha": self.alpha,
"width": self.width,
"depth": self.depth,
"mean": self.mean,
"std": self.std,
"transforms": [t._to_dict() for t in self.transforms], # skipcq: PYL-W0212
}
state.update(self.get_base_init_args())
return state
Loading