diff --git a/ppdet/data/transform/autoaugment_utils.py b/ppdet/data/transform/autoaugment_utils.py index cfa89d374d9..a7859e0f516 100644 --- a/ppdet/data/transform/autoaugment_utils.py +++ b/ppdet/data/transform/autoaugment_utils.py @@ -21,8 +21,9 @@ import inspect import math -from PIL import Image, ImageEnhance +from PIL import Image, ImageEnhance, ImageOps import numpy as np +import random import cv2 from copy import deepcopy @@ -195,7 +196,7 @@ def blend(image1, image2, factor): return np.clip(temp, a_min=0, a_max=255).astype(np.uint8) -def cutout(image, pad_size, replace=0): +def cutout(image, bboxes, pad_size, replace=0, threshold=0.5): """Apply cutout (https://arxiv.org/abs/1708.04552) to image. This operation applies a (2*pad_size x 2*pad_size) mask of zeros to @@ -210,6 +211,8 @@ def cutout(image, pad_size, replace=0): (2*pad_size x 2*pad_size). replace: What pixel value to fill in the image in the area that has the cutout mask applied to it. + threshold: float, the proportion of a box's area covered by the cutout; + boxes whose covered proportion is below this threshold are retained. Returns: An image Tensor that is of type uint8. @@ -217,6 +220,9 @@ def cutout(image, pad_size, replace=0): img = cv2.imread( "/home/vis/gry/train/img_data/test.jpg", cv2.COLOR_BGR2RGB ) new_img = cutout(img, pad_size=50, replace=0) """ + if not (bboxes.size > 0): + return image, bboxes + image_height, image_width = image.shape[0], image.shape[1] cutout_center_height = np.random.randint(low=0, high=image_height) @@ -232,6 +238,15 @@ def cutout(image, pad_size, replace=0): image_width - (left_pad + right_pad) ] padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + + cut_box = [ + left_pad, lower_pad, cutout_center_width + pad_size, + cutout_center_height + pad_size + ] + overlapping_iou = _cut_iou_calculate(cut_box, bboxes, image_height, + image_width) + bboxes = bboxes[overlapping_iou < threshold] + mask = np.pad(np.zeros( cutout_shape, dtype=image.dtype), padding_dims, @@ -244,7 +259,7 @@ def cutout(image, pad_size, replace=0): np.ones_like( image, dtype=image.dtype) * replace, image) - return image.astype(np.uint8) + return image.astype(np.uint8), bboxes def solarize(image, threshold=128): @@ -272,21 +287,27 @@ def color(image, factor): # refer to https://github.com/4uiiurz1/pytorch-auto-augment/blob/024b2eac4140c38df8342f09998e307234cafc80/auto_augment.py#L197 -def contrast(img, factor): - img = ImageEnhance.Contrast(Image.fromarray(img)).enhance(factor) - return np.array(img) +def contrast(image, factor): + image = ImageEnhance.Contrast(Image.fromarray(image)).enhance(factor) + return np.array(image) def brightness(image, factor): """Equivalent of PIL Brightness.""" - degenerate = np.zeros_like(image) - return blend(degenerate, image, factor) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = ImageEnhance.Brightness(image) + image = image.enhance(factor) + return np.array(image) def posterize(image, bits): """Equivalent of PIL Posterize.""" shift = 8 - bits - return np.left_shift(np.right_shift(image, shift), shift) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = ImageOps.posterize(image, shift) + return np.array(image) def rotate(image, degrees, replace): @@ -531,6 +552,9 @@ def _apply_bbox_augmentation(image, bbox, augmentation_func, *args): # Get the sub-tensor that is the image within the bounding box region. 
bbox_content = image[min_y:max_y + 1, min_x:max_x + 1, :] + if bbox_content.shape[0] == 0 or bbox_content.shape[1] == 0: + return image + # Apply the augmentation function to the bbox portion of the image. augmented_bbox_content = augmentation_func(bbox_content, *args) @@ -543,15 +567,15 @@ def _apply_bbox_augmentation(image, bbox, augmentation_func, *args): constant_values=1) # Create a mask that will be used to zero out a part of the original image. - mask_tensor = np.zeros_like(bbox_content) + mask_array = np.zeros_like(bbox_content) - mask_tensor = np.pad(mask_tensor, - [[min_y, (image_height - 1) - max_y], - [min_x, (image_width - 1) - max_x], [0, 0]], - 'constant', - constant_values=1) + mask_array = np.pad(mask_array, + [[min_y, (image_height - 1) - max_y], + [min_x, (image_width - 1) - max_x], [0, 0]], + 'constant', + constant_values=1) # Replace the old bbox content with the new augmented content. - image = image * mask_tensor + augmented_bbox_content + image = image * mask_array + augmented_bbox_content return image.astype(np.uint8) @@ -560,10 +584,10 @@ def _concat_bbox(bbox, bboxes): # Note if all elements in bboxes are -1 (_INVALID_BOX), then this means # we discard bboxes and start the bboxes Tensor with the current bbox. - bboxes_sum_check = np.sum(bboxes) - bbox = np.expand_dims(bbox, 0) + bboxes_sum_check = np.sum(bboxes, axis=-1) + bbox = bbox[np.newaxis, ...] # This check will be true when it is an _INVALID_BOX - if _equal(bboxes_sum_check, -4): + if np.any(bboxes_sum_check == -4.0): bboxes = bbox else: bboxes = np.concatenate([bboxes, bbox], 0) @@ -646,7 +670,9 @@ def _apply_multi_bbox_augmentation(image, bboxes, prob, aug_func, # If the bboxes are empty, then just give it _INVALID_BOX. The result # will be thrown away. - bboxes = np.array((_INVALID_BOX)) if bboxes.size == 0 else bboxes + gt_with_crowd = np.array( + [[-1.0, -1.0]]) if bboxes.size == 0 else bboxes[:, 4:] + bboxes = np.array((_INVALID_BOX)) if bboxes.size == 0 else bboxes[:, :4] assert bboxes.shape[1] == 4, "bboxes.shape[1] must be 4!!!!" @@ -692,19 +718,17 @@ def cond(_idx, _images_and_bboxes): final_bboxes = new_bboxes else: final_bboxes = bboxes - return image, final_bboxes + return image, np.concatenate([final_bboxes, gt_with_crowd], 1) def _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, aug_func, func_changes_bbox, *args): """Checks to be sure num bboxes > 0 before calling inner function.""" num_bboxes = len(bboxes) - new_image = deepcopy(image) - new_bboxes = deepcopy(bboxes) if num_bboxes != 0: - new_image, new_bboxes = _apply_multi_bbox_augmentation( - new_image, new_bboxes, prob, aug_func, func_changes_bbox, *args) - return new_image, new_bboxes + image, bboxes = _apply_multi_bbox_augmentation( + image, bboxes, prob, aug_func, func_changes_bbox, *args) + return image, bboxes def rotate_only_bboxes(image, bboxes, prob, degrees, replace): @@ -809,6 +833,7 @@ def _rotate_bbox(bbox, image_height, image_width, degrees): min_x = int(image_width * (bbox[1] - 0.5)) max_y = -int(image_height * (bbox[2] - 0.5)) max_x = int(image_width * (bbox[3] - 0.5)) + gt_class, is_crowd = bbox[4], bbox[5] coordinates = np.stack([[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]]).astype(np.float32) # Rotate the coordinates according to the rotation matrix clockwise if @@ -824,10 +849,13 @@ def _rotate_bbox(bbox, image_height, image_width, degrees): max_y = -(float(np.min(new_coords[0, :])) / image_height - 0.5) max_x = float(np.max(new_coords[1, :])) / image_width + 0.5 + if max_x < 0. 
or min_x > 1.0 or max_y < 0. or min_y > 1.0: + return None + # Clip the bboxes to be sure the fall between [0, 1]. min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x) min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x) - return np.stack([min_y, min_x, max_y, max_x]) + return np.stack([min_y, min_x, max_y, max_x, gt_class, is_crowd]) def rotate_with_bboxes(image, bboxes, degrees, replace): @@ -839,24 +867,32 @@ def rotate_with_bboxes(image, bboxes, degrees, replace): # pylint:disable=g-long-lambda wrapped_rotate_bbox = lambda bbox: _rotate_bbox(bbox, image_height, image_width, degrees) # pylint:enable=g-long-lambda - new_bboxes = np.zeros_like(bboxes) - for idx in range(len(bboxes)): - new_bboxes[idx] = wrapped_rotate_bbox(bboxes[idx]) - return image, new_bboxes + bboxes = np.array([ + box for box in list(map(wrapped_rotate_bbox, bboxes)) if box is not None + ]) + return image, bboxes def translate_x(image, pixels, replace): """Equivalent of PIL Translate in X dimension.""" - image = Image.fromarray(wrap(image)) - image = image.transform(image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0)) - return unwrap(np.array(image), replace) + if not isinstance(replace, tuple): + replace = tuple(replace) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = image.transform( + image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), fillcolor=replace) + return np.array(image) def translate_y(image, pixels, replace): """Equivalent of PIL Translate in Y dimension.""" - image = Image.fromarray(wrap(image)) - image = image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels)) - return unwrap(np.array(image), replace) + if not isinstance(replace, tuple): + replace = tuple(replace) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = image.transform( + image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), fillcolor=replace) + return np.array(image) def _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal): @@ -880,6 +916,7 @@ def _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal): min_x = int(float(image_width) * bbox[1]) max_y = int(float(image_height) * bbox[2]) max_x = int(float(image_width) * bbox[3]) + gt_class, is_crowd = bbox[4], bbox[5] if shift_horizontal: min_x = np.maximum(0, min_x - pixels) @@ -894,10 +931,14 @@ def _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal): max_y = float(max_y) / float(image_height) max_x = float(max_x) / float(image_width) + # Out of bounds. + if max_x < 0. or min_x > 1.0 or max_y < 0. or min_y > 1.0: + return None + # Clip the bboxes to be sure the fall between [0, 1]. min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x) min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x) - return np.stack([min_y, min_x, max_y, max_x]) + return np.stack([min_y, min_x, max_y, max_x, gt_class, is_crowd]) def translate_bbox(image, bboxes, pixels, replace, shift_horizontal): @@ -918,6 +959,8 @@ def translate_bbox(image, bboxes, pixels, replace, shift_horizontal): image by pixels. The second element of the tuple is bboxes, where now the coordinates will be shifted to reflect the shifted image. 
""" + if not isinstance(replace, tuple): + replace = tuple(replace) if shift_horizontal: image = translate_x(image, pixels, replace) else: @@ -928,11 +971,10 @@ def translate_bbox(image, bboxes, pixels, replace, shift_horizontal): # pylint:disable=g-long-lambda wrapped_shift_bbox = lambda bbox: _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal) # pylint:enable=g-long-lambda - new_bboxes = deepcopy(bboxes) - num_bboxes = len(bboxes) - for idx in range(num_bboxes): - new_bboxes[idx] = wrapped_shift_bbox(bboxes[idx]) - return image.astype(np.uint8), new_bboxes + bboxes = np.array([ + box for box in list(map(wrapped_shift_bbox, bboxes)) if box is not None + ]) + return image.astype(np.uint8), bboxes def shear_x(image, level, replace): @@ -941,9 +983,16 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = Image.fromarray(wrap(image)) - image = image.transform(image.size, Image.AFFINE, (1, level, 0, 0, 1, 0)) - return unwrap(np.array(image), replace) + if not isinstance(replace, tuple): + replace = tuple(replace) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = image.transform( + image.size, + Image.AFFINE, (1, level, 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=replace) + return np.array(image) def shear_y(image, level, replace): @@ -952,9 +1001,16 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = Image.fromarray(wrap(image)) - image = image.transform(image.size, Image.AFFINE, (1, 0, 0, level, 1, 0)) - return unwrap(np.array(image), replace) + if not isinstance(replace, tuple): + replace = tuple(replace) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = image.transform( + image.size, + Image.AFFINE, (1, 0, 0, level, 1, 0), + Image.BICUBIC, + fillcolor=replace) + return np.array(image) def _shear_bbox(bbox, image_height, image_width, level, shear_horizontal): @@ -979,6 +1035,7 @@ def _shear_bbox(bbox, image_height, image_width, level, shear_horizontal): min_x = int(image_width * bbox[1]) max_y = int(image_height * bbox[2]) max_x = int(image_width * bbox[3]) + gt_class, is_crowd = bbox[4], bbox[5] coordinates = np.stack( [[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]]) coordinates = coordinates.astype(np.float32) @@ -998,10 +1055,13 @@ def _shear_bbox(bbox, image_height, image_width, level, shear_horizontal): max_y = float(np.max(new_coords[0, :])) / image_height max_x = float(np.max(new_coords[1, :])) / image_width + if max_x < 0. or min_x > 1.0 or max_y < 0. or min_y > 1.0: + return None + # Clip the bboxes to be sure the fall between [0, 1]. 
min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x) min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x) - return np.stack([min_y, min_x, max_y, max_x]) + return np.stack([min_y, min_x, max_y, max_x, gt_class, is_crowd]) def shear_with_bboxes(image, bboxes, level, replace, shear_horizontal): @@ -1033,105 +1093,35 @@ def shear_with_bboxes(image, bboxes, level, replace, shear_horizontal): # pylint:disable=g-long-lambda wrapped_shear_bbox = lambda bbox: _shear_bbox(bbox, image_height, image_width, level, shear_horizontal) # pylint:enable=g-long-lambda - new_bboxes = deepcopy(bboxes) - num_bboxes = len(bboxes) - for idx in range(num_bboxes): - new_bboxes[idx] = wrapped_shear_bbox(bboxes[idx]) - return image.astype(np.uint8), new_bboxes + bboxes = np.array([ + box for box in list(map(wrapped_shear_bbox, bboxes)) if box is not None + ]) + return image, bboxes def autocontrast(image): - """Implements Autocontrast function from PIL. - - Args: - image: A 3D uint8 tensor. - - Returns: - The image after it has had autocontrast applied to it and will be of type - uint8. - """ - - def scale_channel(image): - """Scale the 2D image using the autocontrast rule.""" - # A possibly cheaper version can be done using cumsum/unique_with_counts - # over the histogram values, rather than iterating over the entire image. - # to compute mins and maxes. - lo = float(np.min(image)) - hi = float(np.max(image)) - - # Scale the image, making the lowest value 0 and the highest value 255. - def scale_values(im): - scale = 255.0 / (hi - lo) - offset = -lo * scale - im = im.astype(np.float32) * scale + offset - img = np.clip(im, a_min=0, a_max=255.0) - return im.astype(np.uint8) - - result = scale_values(image) if hi > lo else image - return result - - # Assumes RGB for now. Scales each channel independently - # and then stacks the result. - s1 = scale_channel(image[:, :, 0]) - s2 = scale_channel(image[:, :, 1]) - s3 = scale_channel(image[:, :, 2]) - image = np.stack([s1, s2, s3], 2) - return image + """Implements Autocontrast function from PIL.""" + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = ImageOps.autocontrast(image) + return np.array(image) def sharpness(image, factor): """Implements Sharpness function from PIL.""" - orig_image = image - image = image.astype(np.float32) - # Make image 4D for conv operation. - # SMOOTH PIL Kernel. - kernel = np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=np.float32) / 13. - result = cv2.filter2D(image, -1, kernel).astype(np.uint8) - - # Blend the final result. - return blend(result, orig_image, factor) + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = ImageEnhance.Sharpness(image).enhance(1 + factor * random.choice( + [-1, 1])) + return np.array(image) def equalize(image): """Implements Equalize function from PIL using.""" - - def scale_channel(im, c): - """Scale the data in the channel to implement equalize.""" - im = im[:, :, c].astype(np.int32) - # Compute the histogram of the image channel. - histo, _ = np.histogram(im, range=[0, 255], bins=256) - - # For the purposes of computing the step, filter out the nonzeros. - nonzero = np.where(np.not_equal(histo, 0)) - nonzero_histo = np.reshape(np.take(histo, nonzero), [-1]) - step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255 - - def build_lut(histo, step): - # Compute the cumulative sum, shifting by step // 2 - # and then normalization by step. 
- lut = (np.cumsum(histo) + (step // 2)) // step - # Shift lut, prepending with 0. - lut = np.concatenate([[0], lut[:-1]], 0) - # Clip the counts to be in range. This is done - # in the C code for image.point. - return np.clip(lut, a_min=0, a_max=255).astype(np.uint8) - - # If step is zero, return the original image. Otherwise, build - # lut from the full histogram and step and then index from it. - if step == 0: - result = im - else: - result = np.take(build_lut(histo, step), im) - - return result.astype(np.uint8) - - # Assumes RGB for now. Scales each channel independently - # and then stacks the result. - s1 = scale_channel(image, 0) - s2 = scale_channel(image, 1) - s3 = scale_channel(image, 2) - image = np.stack([s1, s2, s3], 2) - return image + if isinstance(image, np.ndarray): + image = Image.fromarray(np.uint8(image)) + image = ImageOps.equalize(image) + return np.array(image) def wrap(image): @@ -1184,7 +1174,29 @@ def unwrap(image, replace): return image.astype(np.uint8) -def _cutout_inside_bbox(image, bbox, pad_fraction): +def _cut_iou_calculate(cut_area, box, image_height, image_width): + x_min1, y_min1, x_max1, y_max1 = cut_area + y_min2 = np.int32(box[:, 0] * image_height) + x_min2 = np.int32(box[:, 1] * image_width) + y_max2 = np.int32(box[:, 2] * image_height) + x_max2 = np.int32(box[:, 3] * image_width) + + x_min_overlap = np.maximum(x_min1, x_min2) + y_min_overlap = np.maximum(y_min1, y_min2) + x_max_overlap = np.minimum(x_max1, x_max2) + y_max_overlap = np.minimum(y_max1, y_max2) + + width_overlap = np.maximum(0, x_max_overlap - x_min_overlap) + height_overlap = np.maximum(0, y_max_overlap - y_min_overlap) + + area_overlap = width_overlap * height_overlap + area_box = (x_max2 - x_min2) * (y_max2 - y_min2) + 1e-8 + iou = area_overlap / area_box + + return iou + + +def _cutout_inside_bbox(image, bbox, pad_fraction, bboxes, threshold=0.5): """Generates cutout mask and the mean pixel value of the bbox. First a location is randomly chosen within the image as the center where the @@ -1199,6 +1211,9 @@ def _cutout_inside_bbox(image, bbox, pad_fraction): in reference to the size of the original bbox. If pad_fraction is 0.25, then the cutout mask will be of shape (0.25 * bbox height, 0.25 * bbox width). + bboxes: the image's normalized GT boxes. + threshold: float, the proportion of a box's area covered by the cutout; + boxes whose covered proportion is below this threshold are retained. Returns: A tuple. Fist element is a tensor of the same shape as image where each @@ -1251,10 +1266,20 @@ def _cutout_inside_bbox(image, bbox, pad_fraction): mask = np.expand_dims(mask, 2) mask = np.tile(mask, [1, 1, 3]) - return mask, mean + # xyxy + cut_box = [ + left_pad, lower_pad, cutout_center_width + pad_size_width, + cutout_center_height + pad_size_height + ] + # Compute the proportion of each box covered by the cut region and + # retain only the boxes whose covered proportion is below the threshold + overlapping_iou = _cut_iou_calculate(cut_box, bboxes, image_height, + image_width) + bboxes = bboxes[overlapping_iou < threshold] + return mask, mean, bboxes -def bbox_cutout(image, bboxes, pad_fraction, replace_with_mean): +def bbox_cutout(image, bboxes, pad_fraction, replace_with_mean, threshold=0.5): """Applies cutout to the image according to bbox information. This is a cutout variant that using bbox information to make more informed @@ -1275,6 +1300,8 @@ def bbox_cutout(image, bboxes, pad_fraction, replace_with_mean): we set the value to be 128. 
If replace_with_mean is True then we find the mean pixel values across the channel dimension and use those to fill in where the cutout mask is applied. + threshold: float, the proportion of a box's area covered by the cutout; + boxes whose covered proportion is below this threshold are retained. Returns: A tuple. First element is a tensor of the same shape as image that has @@ -1282,17 +1309,18 @@ def apply_bbox_cutout(image, bboxes, pad_fraction): cutout applied to it. Second element is the bboxes that were passed in that will be unchanged. """ - def apply_bbox_cutout(image, bboxes, pad_fraction): + def apply_bbox_cutout(image, bboxes, pad_fraction, threshold=threshold): """Applies cutout to a single bounding box within image.""" # Choose a single bounding box to apply cutout to. random_index = np.random.randint(0, bboxes.shape[0], dtype=np.int32) # Select the corresponding bbox and apply cutout. - chosen_bbox = np.take(bboxes, random_index, axis=0) - mask, mean = _cutout_inside_bbox(image, chosen_bbox, pad_fraction) + chosen_bbox = bboxes[random_index] + mask, mean, bboxes = _cutout_inside_bbox( + image, chosen_bbox, pad_fraction, bboxes, threshold=threshold) # When applying cutout we either set the pixel value to 128 or to the mean # value inside the bbox. - replace = mean if replace_with_mean else [128] * 3 + replace = mean if replace_with_mean else 128 # Apply the cutout mask to the image. Where the mask is 0 we fill it with # `replace`. @@ -1301,11 +1329,11 @@ def apply_bbox_cutout(image, bboxes, pad_fraction): np.ones_like( image, dtype=image.dtype) * replace, image).astype(image.dtype) - return image + return image, bboxes # Check to see if there are boxes, if so then apply boxcutout. if len(bboxes) != 0: - image = apply_bbox_cutout(image, bboxes, pad_fraction) + image, bboxes = apply_bbox_cutout(image, bboxes, pad_fraction) return image, bboxes @@ -1457,8 +1485,6 @@ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): # Add in replace arg if it is required for the function that is being called. if 'replace' in inspect.getfullargspec(func)[0]: - # Make sure replace is the final argument - assert 'replace' == inspect.getfullargspec(func)[0][-1] args = tuple(list(args) + [replace_value]) # Add bboxes as the second positional argument for the function if it does diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 61a4aacba02..e5049b240aa 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -36,6 +36,7 @@ import logging import cv2 from PIL import Image, ImageDraw +import paddle.vision.transforms.functional as TF import pickle import threading MUTEX = threading.Lock() @@ -495,7 +496,7 @@ class RandomDistort(BaseOperator): """ def __init__(self, - hue=[-18, 18, 0.5], + hue=[-0.5, 0.5, 0.5], saturation=[0.5, 1.5, 0.5], contrast=[0.5, 1.5, 0.5], brightness=[0.5, 1.5, 0.5], @@ -515,19 +516,8 @@ def apply_hue(self, img): low, high, prob = self.hue if np.random.uniform(0., 1.) 
< prob: return img - - img = img.astype(np.float32) - # it works, but result differ from HSV version delta = np.random.uniform(low, high) - u = np.cos(delta * np.pi) - w = np.sin(delta * np.pi) - bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]]) - tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321], - [0.211, -0.523, 0.311]]) - ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647], - [1.0, -1.107, 1.705]]) - t = np.dot(np.dot(ityiq, bt), tyiq).T - img = np.dot(img, t) + img = TF.adjust_hue(img, delta) return img def apply_saturation(self, img): @@ -535,13 +525,7 @@ def apply_saturation(self, img): if np.random.uniform(0., 1.) < prob: return img delta = np.random.uniform(low, high) - img = img.astype(np.float32) - # it works, but result differ from HSV version - gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32) - gray = gray.sum(axis=2, keepdims=True) - gray *= (1.0 - delta) - img *= delta - img += gray + img = TF.adjust_saturation(img, delta) return img def apply_contrast(self, img): @@ -549,8 +533,7 @@ def apply_contrast(self, img): if np.random.uniform(0., 1.) < prob: return img delta = np.random.uniform(low, high) - img = img.astype(np.float32) - img *= delta + img = TF.adjust_contrast(img, delta) return img def apply_brightness(self, img): @@ -558,8 +541,7 @@ def apply_brightness(self, img): if np.random.uniform(0., 1.) < prob: return img delta = np.random.uniform(low, high) - img = img.astype(np.float32) - img += delta + img = TF.adjust_brightness(img, delta) return img def apply(self, sample, context=None): @@ -711,32 +693,41 @@ def apply(self, sample, context=None): Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172 """ im = sample['image'] - gt_bbox = sample['gt_bbox'] + gt_bbox, gt_class, is_crowd = sample['gt_bbox'], sample[ 'gt_class'], sample['is_crowd'] + _labels = np.concatenate([gt_bbox, gt_class, is_crowd], 1) + if not isinstance(im, np.ndarray): raise TypeError("{}: image is not a numpy array.".format(self)) if len(im.shape) != 3: raise ImageError("{}: image is not 3-dimensional.".format(self)) - if len(gt_bbox) == 0: + if gt_bbox.size == 0: return sample height, width, _ = im.shape - norm_gt_bbox = np.ones_like(gt_bbox, dtype=np.float32) - norm_gt_bbox[:, 0] = gt_bbox[:, 1] / float(height) - norm_gt_bbox[:, 1] = gt_bbox[:, 0] / float(width) - norm_gt_bbox[:, 2] = gt_bbox[:, 3] / float(height) - norm_gt_bbox[:, 3] = gt_bbox[:, 2] / float(width) + norm_labels = np.copy(_labels) + norm_labels[:, 0] = _labels[:, 1] / float(height) + norm_labels[:, 1] = _labels[:, 0] / float(width) + norm_labels[:, 2] = _labels[:, 3] / float(height) + norm_labels[:, 3] = _labels[:, 2] / float(width) from .autoaugment_utils import distort_image_with_autoaugment - im, norm_gt_bbox = distort_image_with_autoaugment(im, norm_gt_bbox, - self.autoaug_type) + im, norm_labels = distort_image_with_autoaugment(im, norm_labels, self.autoaug_type) + # if there are no boxes left in the image after augmentation, just return the sample + if norm_labels.size == 0: + return sample - gt_bbox[:, 0] = norm_gt_bbox[:, 1] * float(width) - gt_bbox[:, 1] = norm_gt_bbox[:, 0] * float(height) - gt_bbox[:, 2] = norm_gt_bbox[:, 3] * float(width) - gt_bbox[:, 3] = norm_gt_bbox[:, 2] * float(height) + gt_labels = np.copy(norm_labels) + gt_labels[:, 0] = norm_labels[:, 1] * float(width) + gt_labels[:, 1] = norm_labels[:, 0] * float(height) + gt_labels[:, 2] = norm_labels[:, 3] * float(width) + gt_labels[:, 3] = norm_labels[:, 2] * 
float(height) sample['image'] = im - sample['gt_bbox'] = gt_bbox + sample['gt_bbox'] = gt_labels[:, :4] + sample['gt_class'] = gt_labels[:, 4:5] + sample['is_crowd'] = gt_labels[:, 5:6] return sample @@ -3455,6 +3446,10 @@ class Mosaic(BaseOperator): enable_mixup (bool): whether to enable Mixup or not mixup_prob (float): probability of using Mixup, 1.0 as default mixup_scale (list[int]): scale range of Mixup + iou_thresh_alpha (float): ratio of a clipped mosaic box's area to its unclipped area; + boxes with ratios higher than this value are retained, 0.0 as default, 0.2 recommended. + iou_thresh_beta (float): ratio of a mosaic box's area to the transformed image area; + boxes with ratios higher than this value are retained, 0.0 as default, 0.05 recommended. remove_outside_box (bool): whether remove outside boxes, False as default in COCO dataset, True in MOT dataset """ def __init__(self, @@ -3469,6 +3464,8 @@ def __init__(self, enable_mixup=True, mixup_prob=1.0, mixup_scale=[0.5, 1.5], + iou_thresh_alpha=0.0, + iou_thresh_beta=0.0, remove_outside_box=False): super(Mosaic, self).__init__() self.prob = prob @@ -3482,6 +3479,8 @@ def __init__(self, self.enable_mixup = enable_mixup self.mixup_prob = mixup_prob self.mixup_scale = mixup_scale + self.iou_thresh_alpha = iou_thresh_alpha + self.iou_thresh_beta = iou_thresh_beta self.remove_outside_box = remove_outside_box def get_mosaic_coords(self, mosaic_idx, xc, yc, w, h, input_h, input_w): @@ -3507,60 +3506,89 @@ def get_mosaic_coords(self, mosaic_idx, xc, yc, w, h, input_h, input_w): return (x1, y1, x2, y2), small_coords - def random_affine_augment(self, - img, - labels=[], - input_dim=[640, 640], - degrees=[-10, 10], - scales=[0.1, 2], - shears=[-2, 2], - translates=[-0.1, 0.1]): - # random rotation and scale + def get_affine_matrix(self, + input_dim, + degrees=[-10, 10], + scales=[0.1, 2], + shears=[-2, 2], + translates=[-0.1, 0.1]): + iwidth, iheight = input_dim + + # Rotation and Scale degree = random.uniform(degrees[0], degrees[1]) scale = random.uniform(scales[0], scales[1]) - assert scale > 0, "Argument scale should be positive." 
+ + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + R = cv2.getRotationMatrix2D(angle=degree, center=(0, 0), scale=scale) - M = np.ones([2, 3]) - # random shear + M = np.ones([2, 3]) + # Shear shear = random.uniform(shears[0], shears[1]) shear_x = math.tan(shear * math.pi / 180) shear_y = math.tan(shear * math.pi / 180) + M[0] = R[0] + shear_y * R[1] M[1] = R[1] + shear_x * R[0] - # random translation + # Translation translate = random.uniform(translates[0], translates[1]) - translation_x = translate * input_dim[0] - translation_y = translate * input_dim[1] + translation_x = translate * iwidth # x translation (pixels) + translation_y = translate * iheight # y translation (pixels) + M[0, 2] = translation_x M[1, 2] = translation_y + return M, scale + + def apply_affine_to_bboxes(self, labels, input_dim, M): + num_gts = len(labels) + + # warp corner points + twidth, theight = input_dim + corner_points = np.ones((4 * num_gts, 3)) + corner_points[:, :2] = labels[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + 4 * num_gts, 2) # x1y1, x2y2, x1y2, x2y1 + corner_points = corner_points @M.T # apply affine transform + corner_points = corner_points.reshape(num_gts, 8) + + # create new boxes + corner_xs = corner_points[:, 0::2] + corner_ys = corner_points[:, 1::2] + new_bboxes = (np.concatenate((corner_xs.min(1), corner_ys.min(1), + corner_xs.max(1), corner_ys.max(1))) + .reshape(4, num_gts).T) + + # clip boxes + new_bboxes[:, 0::2] = new_bboxes[:, 0::2].clip(0, twidth) + new_bboxes[:, 1::2] = new_bboxes[:, 1::2].clip(0, theight) + + labels[:, :4] = new_bboxes + + return labels + + def random_affine_augment(self, + img, + labels=[], + input_dim=[640, 640], + degrees=[-10, 10], + scales=[0.1, 2], + shears=[-2, 2], + translates=[-0.1, 0.1]): + # random rotation and scale + M, _ = self.get_affine_matrix( + input_dim=input_dim, + degrees=degrees, + scales=scales, + shears=shears, + translates=translates) # warpAffine img = cv2.warpAffine( img, M, dsize=tuple(input_dim), borderValue=(114, 114, 114)) - num_gts = len(labels) - if num_gts > 0: - # warp corner points - corner_points = np.ones((4 * num_gts, 3)) - corner_points[:, :2] = labels[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( - 4 * num_gts, 2) # x1y1, x2y2, x1y2, x2y1 - # apply affine transform - corner_points = corner_points @M.T - corner_points = corner_points.reshape(num_gts, 8) - - # create new boxes - corner_xs = corner_points[:, 0::2] - corner_ys = corner_points[:, 1::2] - new_bboxes = np.concatenate((corner_xs.min(1), corner_ys.min(1), - corner_xs.max(1), corner_ys.max(1))) - new_bboxes = new_bboxes.reshape(4, num_gts).T - - # clip boxes - new_bboxes[:, 0::2] = np.clip(new_bboxes[:, 0::2], 0, input_dim[0]) - new_bboxes[:, 1::2] = np.clip(new_bboxes[:, 1::2], 0, input_dim[1]) - labels[:, :4] = new_bboxes + if len(labels) > 0: + labels = self.apply_affine_to_bboxes(labels, input_dim, M) return img, labels @@ -3574,6 +3602,7 @@ def __call__(self, sample, context=None): return sample[0] mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd, mosaic_difficult = [], [], [], [] + mosaic_labels = [] input_h, input_w = self.input_dim yc = int(random.uniform(0.5 * input_h, 1.5 * input_h)) xc = int(random.uniform(0.5 * input_w, 1.5 * input_w)) @@ -3581,8 +3610,9 @@ def __call__(self, sample, context=None): # 1. 
get mosaic coords for mosaic_idx, sp in enumerate(sample[:4]): - img = sp['image'] - gt_bbox = sp['gt_bbox'] + img, gt_bbox, gt_class, is_crowd = sp['image'], sp['gt_bbox'], sp[ + 'gt_class'], sp['is_crowd'] + _labels = np.concatenate([gt_bbox, gt_class, is_crowd], 1) h0, w0 = img.shape[:2] scale = min(1. * input_h / h0, 1. * input_w / w0) img = cv2.resize( @@ -3598,49 +3628,26 @@ def __call__(self, sample, context=None): mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2] padw, padh = l_x1 - s_x1, l_y1 - s_y1 + labels = _labels.copy() # Normalized xywh to pixel xyxy format - _gt_bbox = gt_bbox.copy() - if len(gt_bbox) > 0: - _gt_bbox[:, 0] = scale * gt_bbox[:, 0] + padw - _gt_bbox[:, 1] = scale * gt_bbox[:, 1] + padh - _gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw - _gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh - - mosaic_gt_bbox.append(_gt_bbox) - mosaic_gt_class.append(sp['gt_class']) - if 'is_crowd' in sp: - mosaic_is_crowd.append(sp['is_crowd']) - if 'difficult' in sp: - mosaic_difficult.append(sp['difficult']) + if len(_labels) > 0: + labels[:, 0] = scale * _labels[:, 0] + padw + labels[:, 1] = scale * _labels[:, 1] + padh + labels[:, 2] = scale * _labels[:, 2] + padw + labels[:, 3] = scale * _labels[:, 3] + padh + + mosaic_labels.append(labels) # 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd]) - if len(mosaic_gt_bbox): - mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0) - mosaic_gt_class = np.concatenate(mosaic_gt_class, 0) - if mosaic_is_crowd: - mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0) - mosaic_labels = np.concatenate([ - mosaic_gt_bbox, - mosaic_gt_class.astype(mosaic_gt_bbox.dtype), - mosaic_is_crowd.astype(mosaic_gt_bbox.dtype) - ], 1) - elif mosaic_difficult: - mosaic_difficult = np.concatenate(mosaic_difficult, 0) - mosaic_labels = np.concatenate([ - mosaic_gt_bbox, - mosaic_gt_class.astype(mosaic_gt_bbox.dtype), - mosaic_difficult.astype(mosaic_gt_bbox.dtype) - ], 1) - else: - mosaic_labels = np.concatenate([ - mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype) - ], 1) + if len(mosaic_labels): + mosaic_labels = np.concatenate(mosaic_labels, 0) + un_clip_mosaic_labels = np.copy(mosaic_labels) if self.remove_outside_box: # for MOT dataset - flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w - flag2 = mosaic_gt_bbox[:, 2] > 0 - flag3 = mosaic_gt_bbox[:, 1] < 2 * input_h - flag4 = mosaic_gt_bbox[:, 3] > 0 + flag1 = mosaic_labels[:, 0] < 2 * input_w + flag2 = mosaic_labels[:, 2] > 0 + flag3 = mosaic_labels[:, 1] < 2 * input_h + flag4 = mosaic_labels[:, 3] > 0 flag_all = flag1 * flag2 * flag3 * flag4 mosaic_labels = mosaic_labels[flag_all] else: @@ -3652,14 +3659,12 @@ def __call__(self, sample, context=None): 2 * input_w) mosaic_labels[:, 3] = np.clip(mosaic_labels[:, 3], 0, 2 * input_h) - else: - mosaic_labels = np.zeros((1, 6)) # 3. 
random_affine augment mosaic_img, mosaic_labels = self.random_affine_augment( mosaic_img, mosaic_labels, - input_dim=self.input_dim, + input_dim=[input_w, input_h], degrees=self.degrees, translates=self.translate, scales=self.scale, @@ -3669,27 +3674,22 @@ def __call__(self, sample, context=None): # optinal, not used(enable_mixup=False) in tiny/nano if (self.enable_mixup and not len(mosaic_labels) == 0 and random.random() < self.mixup_prob): - sample_mixup = sample[4] - mixup_img = sample_mixup['image'] - if 'is_crowd' in sample_mixup: - cp_labels = np.concatenate([ - sample_mixup['gt_bbox'], - sample_mixup['gt_class'].astype(mosaic_labels.dtype), - sample_mixup['is_crowd'].astype(mosaic_labels.dtype) - ], 1) - elif 'difficult' in sample_mixup: - cp_labels = np.concatenate([ - sample_mixup['gt_bbox'], - sample_mixup['gt_class'].astype(mosaic_labels.dtype), - sample_mixup['difficult'].astype(mosaic_labels.dtype) - ], 1) - else: - cp_labels = np.concatenate([ - sample_mixup['gt_bbox'], - sample_mixup['gt_class'].astype(mosaic_labels.dtype) - ], 1) - mosaic_img, mosaic_labels = self.mixup_augment( - mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img) + mosaic_img, mosaic_labels, add_labels = self.mixup_augment( + mosaic_img, mosaic_labels, self.input_dim, sample[-1]) + un_clip_mosaic_labels = np.vstack( + (un_clip_mosaic_labels, add_labels)) + + # Only retain boxes that keep enough of their unclipped area, or still + # cover enough of the output image, after truncation + bbox_area = (mosaic_labels[:, 2] - mosaic_labels[:, 0]) * ( + mosaic_labels[:, 3] - mosaic_labels[:, 1]) + un_clip_bbox_area = ( + un_clip_mosaic_labels[:, 2] - un_clip_mosaic_labels[:, 0] + ) * (un_clip_mosaic_labels[:, 3] - un_clip_mosaic_labels[:, 1]) + 1e-8 + bbox_un_clip_bbox_iou = bbox_area / un_clip_bbox_area + bbox_image_iou = bbox_area / (self.input_dim[0] * self.input_dim[1]) + mosaic_labels = mosaic_labels[ + np.logical_or(bbox_un_clip_bbox_iou > self.iou_thresh_alpha, + bbox_image_iou > self.iou_thresh_beta)] sample0 = sample[0] sample0['image'] = mosaic_img.astype(np.uint8) # can not be float32 @@ -3701,12 +3701,12 @@ def __call__(self, sample, context=None): sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32) if 'is_crowd' in sample[0]: sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32) - if 'difficult' in sample[0]: - sample0['difficult'] = mosaic_labels[:, 5:6].astype(np.float32) return sample0 - def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels, - img): + def mixup_augment(self, origin_img, origin_labels, input_dim, cp_sample): + img, gt_bbox, gt_class, is_crowd = cp_sample['image'], cp_sample[ 'gt_bbox'], cp_sample['gt_class'], cp_sample['is_crowd'] + cp_labels = np.concatenate([gt_bbox, gt_class, is_crowd], 1) jit_factor = random.uniform(*self.mixup_scale) FLIP = random.uniform(0, 1) > 0.5 if len(img.shape) == 3: @@ -3786,7 +3786,7 @@ def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels, origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype( np.float32) - return origin_img.astype(np.uint8), origin_labels + return origin_img.astype(np.uint8), origin_labels, labels @register_op
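Notes (not part of the patch). Despite the `iou` naming, `_cut_iou_calculate` returns the fraction of each GT box's area covered by the cut rectangle, and `cutout` / `_cutout_inside_bbox` then drop boxes whose covered fraction reaches `threshold`. A minimal self-contained sketch of that rule, assuming normalized [ymin, xmin, ymax, xmax] boxes as used throughout autoaugment_utils (the helper name `cut_coverage` is hypothetical):

import numpy as np

def cut_coverage(cut_box, boxes, h, w):
    # cut_box is [x_min, y_min, x_max, y_max] in pixels; boxes is an (N, 4)
    # array of normalized [y_min, x_min, y_max, x_max] rows.
    x1, y1, x2, y2 = cut_box
    by1, bx1 = boxes[:, 0] * h, boxes[:, 1] * w
    by2, bx2 = boxes[:, 2] * h, boxes[:, 3] * w
    iw = np.maximum(0, np.minimum(x2, bx2) - np.maximum(x1, bx1))
    ih = np.maximum(0, np.minimum(y2, by2) - np.maximum(y1, by1))
    # Overlap area divided by each box's own area, not a true IoU.
    return (iw * ih) / ((bx2 - bx1) * (by2 - by1) + 1e-8)

boxes = np.array([[0.1, 0.1, 0.3, 0.3], [0.6, 0.6, 0.9, 0.9]])
coverage = cut_coverage([50, 50, 200, 200], boxes, h=400, w=400)
print(boxes[coverage < 0.5])  # the first box is ~77% covered and is dropped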
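The updated AutoAugment operator widens each box to six columns so class ids and crowd flags travel through the geometric ops and stay aligned with the boxes that survive (out-of-bounds boxes now return None and are filtered out). A sketch of that layout, assuming pixel-space gt_bbox rows in [x1, y1, x2, y2] order as elsewhere in operators.py:

import numpy as np

gt_bbox = np.array([[10., 20., 110., 220.]], dtype=np.float32)  # x1, y1, x2, y2
gt_class = np.array([[3.]], dtype=np.float32)
is_crowd = np.array([[0.]], dtype=np.float32)
height, width = 480, 640

labels = np.concatenate([gt_bbox, gt_class, is_crowd], 1)  # (N, 6)
norm = np.copy(labels)
norm[:, 0] = labels[:, 1] / float(height)  # ymin
norm[:, 1] = labels[:, 0] / float(width)   # xmin
norm[:, 2] = labels[:, 3] / float(height)  # ymax
norm[:, 3] = labels[:, 2] / float(width)   # xmax
# Columns 4 (class) and 5 (crowd flag) pass through every bbox op
# untouched and are split back out after de-normalization.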
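On the RandomDistort change: the hue default moves from [-18, 18] to [-0.5, 0.5] because paddle.vision.transforms.functional.adjust_hue takes a hue_factor in [-0.5, 0.5], interpreted as a fraction of the full hue cycle, rather than the deltas the removed YIQ-matrix code consumed. A quick sketch under that assumption:

import numpy as np
import paddle.vision.transforms.functional as TF

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
# hue_factor must lie in [-0.5, 0.5]; +/-0.5 shifts hue by half the color
# wheel, so the new default range samples the full legal interval.
shifted = TF.adjust_hue(img, np.random.uniform(-0.5, 0.5))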
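Finally, a hedged configuration sketch for the new Mosaic thresholds, assuming the constructor keywords shown in the diff; the values follow the docstring's recommendations, and the 0.0 defaults only discard degenerate zero-area boxes:

from ppdet.data.transform.operators import Mosaic

mosaic = Mosaic(
    prob=1.0,
    input_dim=[640, 640],
    # A clipped box is kept if it retains more than 20% of its unclipped
    # area, or if it still covers more than 5% of the output image.
    iou_thresh_alpha=0.2,
    iou_thresh_beta=0.05)

In a YAML training config the same keys would sit under the Mosaic entry of the sample transforms.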