--- a
+++ b/libs/datasets/joint_augment.py
@@ -0,0 +1,336 @@
+
+import numpy as np
+import random
+import numbers
+from PIL import Image, ImageEnhance, ImageOps
+from .augment import to_pil_image
+import torch
+from torchvision.transforms import functional as F
+# _pil_interpolation_to_str is a private helper of the older torchvision
+# releases targeted here (the same versions whose F.affine still accepts
+# resample/fillcolor); RandomAffine.__repr__ below relies on it.
+from torchvision.transforms.transforms import _pil_interpolation_to_str
+
+
+class Compose():
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, mask):
+        assert img.size == mask.size
+        for t in self.transforms:
+            img, mask = t(img, mask)
+        return img, mask
+
+class To_Tensor():
+    def __call__(self, arr, arr2):
+        if len(np.array(arr).shape) == 2:
+            arr = np.array(arr)[:, :, None]
+        if len(np.array(arr2).shape) == 2:
+            arr2 = np.array(arr2)[:, :, None]
+        arr = torch.from_numpy(np.array(arr).transpose(2, 0, 1))
+        arr2 = torch.from_numpy(np.array(arr2).transpose(2, 0, 1))
+        return arr, arr2
+
+class To_PIL_Image():
+    def __call__(self, img, mask):
+        img = to_pil_image(img)
+        mask = to_pil_image(mask)
+        return img, mask
+
+class RandomVerticalFlip():
+    def __init__(self, prob):
+        self.prob = prob
+
+    def __call__(self, img, mask):
+        if random.random() < self.prob:
+            if isinstance(img, Image.Image):
+                return img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM)
+            if isinstance(img, np.ndarray):
+                return np.flip(img, axis=0), np.flip(mask, axis=0)
+        return img, mask
+
+class RandomHorizontallyFlip():
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, img, mask):
+        if random.random() < self.prob:
+            if isinstance(img, Image.Image):
+                return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
+            if isinstance(img, np.ndarray):
+                return np.flip(img, axis=1), np.flip(mask, axis=1)
+        return img, mask
+
+class RandomRotate():
+    def __init__(self, degrees, prob=0.5):
+        self.prob = prob
+        self.degrees = degrees
+
+    def __call__(self, img, mask):
+        if random.random() < self.prob:
+            rotate_degree = random.uniform(self.degrees[0], self.degrees[1])
+            return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST)
+        return img, mask
+
+class FixResize():
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, img=None, mask=None):
+        if mask is None:
+            return img.resize(self.size, Image.BILINEAR)
+        if img is None:
+            return mask.resize(self.size, Image.NEAREST)
+        return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST)
+
+class Scale(object):
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, img, mask):
+        assert img.size == mask.size
+        w, h = img.size
+        if (w >= h and w == self.size) or (h >= w and h == self.size):
+            return img, mask
+        if w > h:
+            ow = self.size
+            oh = int(self.size * h / w)
+            return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)
+        else:
+            oh = self.size
+            ow = int(self.size * w / h)
+            return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)
+
+class RandomCrop(object):
+    def __init__(self, size, padding=0):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+        self.padding = padding
+
+    def __call__(self, img, mask):
+        if self.padding > 0:
+            img = ImageOps.expand(img, border=self.padding, fill=0)
+            mask = ImageOps.expand(mask, border=self.padding, fill=0)
+
+        assert img.size == mask.size
+        w, h = img.size
+        th, tw = self.size
+        if w == tw and h == th:
+            return img, mask
+        if w < tw or h < th:
+            return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST)
+
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+        return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
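+
+# Note on resampling: the transforms in this module that resample pixels pass
+# Image.BILINEAR for the image and Image.NEAREST for the mask, so label ids
+# are never blended across class boundaries. An illustrative way to chain the
+# geometric transforms above (example settings only, not this repository's
+# defaults):
+#
+#     augment = Compose([
+#         RandomHorizontallyFlip(prob=0.5),
+#         RandomVerticalFlip(prob=0.5),
+#         RandomCrop(size=256, padding=8),
+#     ])
+#     img, mask = augment(img, mask)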
+
+class RandomSized(object):
+    def __init__(self, size):
+        self.size = size
+        self.scale = Scale(self.size)
+        self.crop = RandomCrop(self.size)
+
+    def __call__(self, img, mask):
+        assert img.size == mask.size
+
+        w = int(random.uniform(0.5, 2) * img.size[0])
+        h = int(random.uniform(0.5, 2) * img.size[1])
+
+        img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
+
+        return self.crop(*self.scale(img, mask))
+
+class ScaleRatio():
+    def __init__(self, scale_factor=1):
+        self.scale_factor = scale_factor
+
+    def __call__(self, img, interpolation):
+        w, h = img.size
+        new_h = int(h * self.scale_factor)
+        new_w = int(w * self.scale_factor)
+        return img.resize((new_w, new_h), interpolation)
+
+class RandomScale():
+    def __init__(self, min_factor=0.8, max_factor=1.2, prob=0.5):
+        self.min_factor = min_factor
+        self.max_factor = max_factor
+        self.prob = prob
+
+    def __scale(self, img, scale_factor, interpolation):
+        w, h = img.size
+        new_h = int(h * scale_factor)
+        new_w = int(w * scale_factor)
+        return img.resize((new_w, new_h), interpolation)
+
+    def __call__(self, img, mask):
+        if random.random() < self.prob:
+            factor = np.random.uniform(self.min_factor, self.max_factor)
+            return self.__scale(img, factor, Image.BILINEAR), self.__scale(mask, factor, Image.NEAREST)
+        return img, mask
+
+
+class Resize(object):
+    def __init__(self, min_size, max_size):
+        if not isinstance(min_size, (list, tuple)):
+            min_size = (min_size,)
+        self.min_size = min_size
+        self.max_size = max_size
+
+    # modified from torchvision to add support for max size
+    def get_size(self, image_size):
+        w, h = image_size
+        size = random.choice(self.min_size)
+        max_size = self.max_size
+        if max_size is not None:
+            min_original_size = float(min((w, h)))
+            max_original_size = float(max((w, h)))
+            if max_original_size / min_original_size * size > max_size:
+                size = int(round(max_size * min_original_size / max_original_size))
+
+        if (w <= h and w == size) or (h <= w and h == size):
+            return (h, w)
+
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+        else:
+            oh = size
+            ow = int(size * w / h)
+
+        return (oh, ow)
+
+    def __call__(self, image, target=None):
+        size = self.get_size(image.size)
+        image = F.resize(image, size)
+        if isinstance(target, list):
+            target = [t.resize(image.size) for t in target]
+        elif target is None:
+            return image
+        else:
+            target = target.resize(image.size, Image.NEAREST)
+        return image, target
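+
+# Worked example for Resize.get_size (illustrative numbers): with
+# min_size=(800,) and max_size=1333, an input of w=2000, h=800 first targets
+# the shorter side at 800, but 2000 / 800 * 800 = 2000 exceeds max_size, so the
+# target becomes round(1333 * 800 / 2000) = 533; the shorter side then maps to
+# 533 and the longer one to int(533 * 2000 / 800) = 1332, giving
+# (oh, ow) = (533, 1332), which F.resize consumes as (height, width).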
+
+
+class RandomAffine(object):
+    """Random affine transformation of the image keeping center invariant
+
+    Args:
+        degrees (sequence or float or int): Range of degrees to select from.
+            If degrees is a number instead of a sequence like (min, max), the range
+            of degrees will be (-degrees, +degrees). Set to 0 to deactivate rotations.
+        translate (tuple, optional): Tuple of maximum absolute fractions for horizontal
+            and vertical translations. For example, with translate=(a, b) the horizontal
+            shift is randomly sampled in the range -img_width * a < dx < img_width * a and
+            the vertical shift in the range -img_height * b < dy < img_height * b.
+            Will not translate by default.
+        scale (tuple, optional): Scaling factor interval, e.g. (a, b); the scale is
+            randomly sampled from the range a <= scale <= b. Will keep the original
+            scale by default.
+        shear (sequence or float or int, optional): Range of degrees to select from.
+            If shear is a number instead of a sequence like (min, max), the range of
+            degrees will be (-shear, +shear). Will not apply shear by default.
+        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+            An optional resampling filter.
+            See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
+            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+        fillcolor (int): Optional fill color for the area outside the transform in the
+            output image. (Pillow>=5.0.0)
+    """
+
+    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0, prob=0.5):
+        self.prob = prob
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
+                "degrees should be a list or tuple and it must be of length 2."
+            self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError("translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError("If shear is a single number, it must be positive.")
+                self.shear = (-shear, shear)
+            else:
+                assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
+                    "shear should be a list or tuple and it must be of length 2."
+                self.shear = shear
+        else:
+            self.shear = shear
+
+        self.resample = resample
+        self.fillcolor = fillcolor
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, img_size):
+        """Get parameters for affine transformation
+
+        Returns:
+            sequence: params to be passed to the affine transformation
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+        if translate is not None:
+            max_dx = translate[0] * img_size[0]
+            max_dy = translate[1] * img_size[1]
+            translations = (np.round(random.uniform(-max_dx, max_dx)),
+                            np.round(random.uniform(-max_dy, max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = random.uniform(scale_ranges[0], scale_ranges[1])
+        else:
+            scale = 1.0
+
+        if shears is not None:
+            shear = random.uniform(shears[0], shears[1])
+        else:
+            shear = 0.0
+
+        return angle, translations, scale, shear
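+
+    # Note: get_params returns (angle, (dx, dy), scale, shear), which __call__
+    # below unpacks positionally into torchvision's functional affine call,
+    # F.affine(img, angle, translate, scale, shear, ...). This relies on the
+    # older torchvision signature that still accepts resample/fillcolor.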
+
+    def __call__(self, img, mask):
+        """
+        Args:
+            img (PIL Image): Image to be transformed.
+            mask (PIL Image): Mask to be transformed with the same parameters.
+
+        Returns:
+            PIL Image: Affine transformed image and mask.
+        """
+        if random.random() < self.prob:
+            ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
+            # The mask is resampled with NEAREST, matching the other transforms
+            # in this module, so label ids are not interpolated.
+            return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor), \
+                   F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor)
+        return img, mask
+
+    def __repr__(self):
+        s = '{name}(degrees={degrees}'
+        if self.translate is not None:
+            s += ', translate={translate}'
+        if self.scale is not None:
+            s += ', scale={scale}'
+        if self.shear is not None:
+            s += ', shear={shear}'
+        if self.resample > 0:
+            s += ', resample={resample}'
+        if self.fillcolor != 0:
+            s += ', fillcolor={fillcolor}'
+        s += ')'
+        d = dict(self.__dict__)
+        d['resample'] = _pil_interpolation_to_str[d['resample']]
+        return s.format(name=self.__class__.__name__, **d)
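+
+
+# ---------------------------------------------------------------------------
+# Minimal smoke test / usage sketch. It only runs when the module is executed
+# as a script (e.g. `python -m libs.datasets.joint_augment`, so the relative
+# import above resolves). The transform settings below are illustrative
+# examples, not the configuration used by the datasets in this repository.
+if __name__ == "__main__":
+    # Build a dummy RGB image and a single-channel label mask of the same size.
+    dummy_img = Image.fromarray(np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8))
+    dummy_mask = Image.fromarray(np.random.randint(0, 2, (240, 320), dtype=np.uint8))
+
+    pipeline = Compose([
+        RandomHorizontallyFlip(prob=0.5),
+        RandomRotate(degrees=(-10, 10), prob=0.5),
+        FixResize(256),
+        To_Tensor(),
+    ])
+    out_img, out_mask = pipeline(dummy_img, dummy_mask)
+    # Expected shapes: torch.Size([3, 256, 256]) and torch.Size([1, 256, 256]).
+    print(out_img.shape, out_mask.shape)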