a b/transform.py
1
import numpy as np
2
from skimage.transform import rescale, rotate
3
from torchvision.transforms import Compose
4
5
6
def transforms(scale=None, angle=None, flip_prob=None):
    """Build the augmentation pipeline for ``(image, mask)`` samples.

    Every argument that is not ``None`` enables one augmentation. Enabled
    augmentations are applied in a fixed order: scaling, then rotation, then
    horizontal flip.

    Args:
        scale: Passed to ``Scale``; enables random zoom when not ``None``.
        angle: Passed to ``Rotate``; enables random rotation when not ``None``.
        flip_prob: Passed to ``HorizontalFlip``; enables random mirroring
            when not ``None``.

    Returns:
        A ``torchvision.transforms.Compose`` wrapping the enabled transforms
        (possibly an empty one when every argument is ``None``).
    """
    # Order of this table is the order the augmentations run in.
    candidates = (
        (scale, Scale),
        (angle, Rotate),
        (flip_prob, HorizontalFlip),
    )
    return Compose(
        [factory(setting) for setting, factory in candidates if setting is not None]
    )
17
18
19
class Scale(object):
    """Randomly zoom an ``(image, mask)`` pair in or out, keeping its size.

    A zoom factor is drawn uniformly from ``[1 - scale, 1 + scale]`` and
    applied to the two leading (spatial) axes of both arrays. The mask is
    resampled with nearest-neighbour interpolation (``order=0``) so label
    values are not blended. Afterwards each spatial axis is independently
    zero-padded (when it shrank) or centre-cropped (when it grew) back to the
    image's original size, so non-square inputs are handled correctly.

    Both arrays are expected to be channels-last, i.e. ``(H, W, C)``; the
    trailing axes are passed through untouched.
    """

    def __init__(self, scale):
        # Maximum relative deviation of the zoom factor from 1.0.
        self.scale = scale

    @staticmethod
    def _rescale(array, factor, **kwargs):
        """Rescale the two leading axes of a channels-last ``array`` by ``factor``.

        scikit-image 0.19 replaced ``multichannel=True`` with
        ``channel_axis=-1``, and later releases removed ``multichannel``
        entirely. Try the new keyword first and fall back to the old one only
        when an older scikit-image rejects ``channel_axis``.
        """
        try:
            return rescale(array, (factor, factor), channel_axis=-1, **kwargs)
        except TypeError as err:
            # Only swallow the "unexpected keyword 'channel_axis'" error.
            if "channel_axis" not in str(err):
                raise
        return rescale(array, (factor, factor), multichannel=True, **kwargs)

    @staticmethod
    def _fit(array, target_shape):
        """Zero-pad or centre-crop the two leading axes of ``array`` to ``target_shape``.

        An axis that is too short is padded with zeros, putting the smaller
        half of the padding before the data (floor before, ceil after). An
        axis that is too long is cropped to its central ``target`` entries.
        Axes beyond the first two are left unchanged.
        """
        pad_width = []
        crop = []
        for current, target in zip(array.shape[:2], target_shape):
            if current < target:
                before = (target - current) // 2
                pad_width.append((before, target - current - before))
                crop.append(slice(None))
            else:
                start = (current - target) // 2
                pad_width.append((0, 0))
                crop.append(slice(start, start + target))
        pad_width.extend([(0, 0)] * (array.ndim - 2))

        array = array[tuple(crop)]
        if any(before or after for before, after in pad_width):
            array = np.pad(array, pad_width, mode="constant", constant_values=0)
        return array

    def __call__(self, sample):
        """Zoom ``sample = (image, mask)`` by one random factor.

        Returns:
            ``(image, mask)`` with the same spatial size as the input image.
        """
        image, mask = sample

        # Spatial size to restore after resampling.
        target_shape = image.shape[:2]

        factor = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)

        image = self._rescale(
            image,
            factor,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        # order=0 (nearest neighbour) keeps mask labels discrete.
        mask = self._rescale(
            mask,
            factor,
            order=0,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )

        return self._fit(image, target_shape), self._fit(mask, target_shape)
61
62
63
class Rotate(object):
    """Rotate an ``(image, mask)`` pair by one shared random angle.

    The angle is drawn uniformly from ``[-angle, angle]`` and handed to
    ``skimage.transform.rotate`` (which interprets it in degrees). The output
    keeps the input shape (``resize=False``), and regions rotated in from
    outside are filled with a constant (``mode="constant"``).
    """

    def __init__(self, angle):
        # Half-width of the symmetric range the rotation angle is drawn from.
        self.angle = angle

    def __call__(self, sample):
        """Return ``(image, mask)`` both rotated by the same random angle."""
        image, mask = sample

        limit = self.angle
        theta = np.random.uniform(low=-limit, high=limit)

        common = dict(resize=False, preserve_range=True, mode="constant")
        rotated_image = rotate(image, theta, **common)
        # Nearest-neighbour interpolation keeps mask labels discrete.
        rotated_mask = rotate(mask, theta, order=0, **common)

        return rotated_image, rotated_mask
77
78
79
class HorizontalFlip(object):
    """Mirror an ``(image, mask)`` pair left-to-right with a given probability."""

    def __init__(self, flip_prob):
        # Compared against one uniform [0, 1) draw per call; the pair is
        # flipped unless the draw exceeds this value.
        self.flip_prob = flip_prob

    def __call__(self, sample):
        """Return ``sample = (image, mask)``, possibly flipped along axis 1."""
        image, mask = sample

        if not np.random.rand() > self.flip_prob:
            # np.fliplr returns a negatively-strided view; copy() materialises
            # a contiguous array instead.
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()

        return image, mask