Diff of /transform.py [000000] .. [9cc651]

Switch to side-by-side view

--- a
+++ b/transform.py
@@ -0,0 +1,93 @@
+import numpy as np
+from skimage.transform import rescale, rotate
+from torchvision.transforms import Compose
+
+
+def transforms(scale=None, angle=None, flip_prob=None):
+    transform_list = []
+
+    if scale is not None:
+        transform_list.append(Scale(scale))
+    if angle is not None:
+        transform_list.append(Rotate(angle))
+    if flip_prob is not None:
+        transform_list.append(HorizontalFlip(flip_prob))
+
+    return Compose(transform_list)
+
+
+class Scale(object):
+
+    def __init__(self, scale):
+        self.scale = scale
+
+    def __call__(self, sample):
+        image, mask = sample
+
+        img_size = image.shape[0]
+
+        scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)
+
+        image = rescale(
+            image,
+            (scale, scale),
+            multichannel=True,
+            preserve_range=True,
+            mode="constant",
+            anti_aliasing=False,
+        )
+        mask = rescale(
+            mask,
+            (scale, scale),
+            order=0,
+            multichannel=True,
+            preserve_range=True,
+            mode="constant",
+            anti_aliasing=False,
+        )
+
+        if scale < 1.0:
+            diff = (img_size - image.shape[0]) / 2.0
+            padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),)
+            image = np.pad(image, padding, mode="constant", constant_values=0)
+            mask = np.pad(mask, padding, mode="constant", constant_values=0)
+        else:
+            x_min = (image.shape[0] - img_size) // 2
+            x_max = x_min + img_size
+            image = image[x_min:x_max, x_min:x_max, ...]
+            mask = mask[x_min:x_max, x_min:x_max, ...]
+
+        return image, mask
+
+
+class Rotate(object):
+
+    def __init__(self, angle):
+        self.angle = angle
+
+    def __call__(self, sample):
+        image, mask = sample
+
+        angle = np.random.uniform(low=-self.angle, high=self.angle)
+        image = rotate(image, angle, resize=False, preserve_range=True, mode="constant")
+        mask = rotate(
+            mask, angle, resize=False, order=0, preserve_range=True, mode="constant"
+        )
+        return image, mask
+
+
+class HorizontalFlip(object):
+
+    def __init__(self, flip_prob):
+        self.flip_prob = flip_prob
+
+    def __call__(self, sample):
+        image, mask = sample
+
+        if np.random.rand() > self.flip_prob:
+            return image, mask
+
+        image = np.fliplr(image).copy()
+        mask = np.fliplr(mask).copy()
+
+        return image, mask