|
a |
|
b/transform.py |
|
|
1 |
import numpy as np |
|
|
2 |
from skimage.transform import rescale, rotate |
|
|
3 |
from torchvision.transforms import Compose |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def transforms(scale=None, angle=None, flip_prob=None): |
|
|
7 |
transform_list = [] |
|
|
8 |
|
|
|
9 |
if scale is not None: |
|
|
10 |
transform_list.append(Scale(scale)) |
|
|
11 |
if angle is not None: |
|
|
12 |
transform_list.append(Rotate(angle)) |
|
|
13 |
if flip_prob is not None: |
|
|
14 |
transform_list.append(HorizontalFlip(flip_prob)) |
|
|
15 |
|
|
|
16 |
return Compose(transform_list) |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class Scale(object): |
|
|
20 |
|
|
|
21 |
def __init__(self, scale): |
|
|
22 |
self.scale = scale |
|
|
23 |
|
|
|
24 |
def __call__(self, sample): |
|
|
25 |
image, mask = sample |
|
|
26 |
|
|
|
27 |
img_size = image.shape[0] |
|
|
28 |
|
|
|
29 |
scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale) |
|
|
30 |
|
|
|
31 |
image = rescale( |
|
|
32 |
image, |
|
|
33 |
(scale, scale), |
|
|
34 |
multichannel=True, |
|
|
35 |
preserve_range=True, |
|
|
36 |
mode="constant", |
|
|
37 |
anti_aliasing=False, |
|
|
38 |
) |
|
|
39 |
mask = rescale( |
|
|
40 |
mask, |
|
|
41 |
(scale, scale), |
|
|
42 |
order=0, |
|
|
43 |
multichannel=True, |
|
|
44 |
preserve_range=True, |
|
|
45 |
mode="constant", |
|
|
46 |
anti_aliasing=False, |
|
|
47 |
) |
|
|
48 |
|
|
|
49 |
if scale < 1.0: |
|
|
50 |
diff = (img_size - image.shape[0]) / 2.0 |
|
|
51 |
padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),) |
|
|
52 |
image = np.pad(image, padding, mode="constant", constant_values=0) |
|
|
53 |
mask = np.pad(mask, padding, mode="constant", constant_values=0) |
|
|
54 |
else: |
|
|
55 |
x_min = (image.shape[0] - img_size) // 2 |
|
|
56 |
x_max = x_min + img_size |
|
|
57 |
image = image[x_min:x_max, x_min:x_max, ...] |
|
|
58 |
mask = mask[x_min:x_max, x_min:x_max, ...] |
|
|
59 |
|
|
|
60 |
return image, mask |
|
|
61 |
|
|
|
62 |
|
|
|
63 |
class Rotate(object): |
|
|
64 |
|
|
|
65 |
def __init__(self, angle): |
|
|
66 |
self.angle = angle |
|
|
67 |
|
|
|
68 |
def __call__(self, sample): |
|
|
69 |
image, mask = sample |
|
|
70 |
|
|
|
71 |
angle = np.random.uniform(low=-self.angle, high=self.angle) |
|
|
72 |
image = rotate(image, angle, resize=False, preserve_range=True, mode="constant") |
|
|
73 |
mask = rotate( |
|
|
74 |
mask, angle, resize=False, order=0, preserve_range=True, mode="constant" |
|
|
75 |
) |
|
|
76 |
return image, mask |
|
|
77 |
|
|
|
78 |
|
|
|
79 |
class HorizontalFlip(object): |
|
|
80 |
|
|
|
81 |
def __init__(self, flip_prob): |
|
|
82 |
self.flip_prob = flip_prob |
|
|
83 |
|
|
|
84 |
def __call__(self, sample): |
|
|
85 |
image, mask = sample |
|
|
86 |
|
|
|
87 |
if np.random.rand() > self.flip_prob: |
|
|
88 |
return image, mask |
|
|
89 |
|
|
|
90 |
image = np.fliplr(image).copy() |
|
|
91 |
mask = np.fliplr(mask).copy() |
|
|
92 |
|
|
|
93 |
return image, mask |