Diff of /utils/augmentations.py [000000] .. [f2ca4d]

Switch to unified view

a b/utils/augmentations.py
1
import numpy as np
2
import nibabel as nib
3
import scipy.ndimage
4
import warnings
5
import PP
6
import sys
7
8
#---------------------------------------------
9
#Functions for image augmentations on 3D input
10
#---------------------------------------------
11
12
#img_b, label_b is (batch_num) x 1 x dim1 x dim2 x dim3
13
#takes in a list of 3D images (1st one is input, 2nd one needs to be label)
14
def augmentPatchLossy(imgs, rotation=[5,5,5], scale_min=0.9, scale_max=1.1, flip_lvl = 0):
15
    new_imgs = []
16
17
    rot_x = np.random.uniform(-rotation[0], rotation[0]) * np.pi / 180.0
18
    rot_y = np.random.uniform(-rotation[1], rotation[1]) * np.pi / 180.0
19
    rot_z = np.random.uniform(-rotation[2], rotation[2]) * np.pi / 180.0
20
21
    zoom_val = np.random.uniform(scale_min, scale_max)
22
    for i in range(len(imgs)):
23
        l = convertBatchToList(imgs[i])
24
        if i == 0:
25
            spline_orders = [3] * len(l)
26
        else:
27
            spline_orders = [0] * len(l)
28
        scaled = applyScale(l, zoom_val, spline_orders)
29
        rotated = applyRotation(scaled, [rot_x, rot_y, rot_z], spline_orders)
30
        new_imgs.append(convertListToBatch(rotated))
31
    return imgs
32
33
def convertBatchToList(img):
34
    l = []
35
    b, c, d1, d2, d3 = img.shape
36
    for i in range(img.shape[0]):
37
        l.append(img[i,:,:,:,:].reshape([1,c,d1,d2,d3]))
38
    return l
39
40
def convertListToBatch(img_list):
41
    b, c, d1, d2, d3 = img_list[0].shape
42
    a = np.zeros([len(img_list), c, d1,d2,d3])
43
    for i in range(len(img_list)):
44
        a[i,:,:,:,:] = img_list[i]
45
    return a
46
47
def augmentPatchLossLess(imgs):
48
    new_imgs = []
49
50
    p = np.random.rand(3) > 0.5
51
    locations = np.where(p == 1)[0] + 2
52
53
    for i in range(len(imgs)):
54
        l = convertBatchToList(imgs[i])
55
        if i == 0:
56
            spline_orders = [3] * len(l)
57
        else:
58
            spline_orders = [0] * len(l)
59
        flipped = applyFLIPS2(l, locations)
60
61
        rot_x = np.random.randint(4) * np.pi / 2.0 # (0,1,2,3)*90/180.0
62
        rot_y = np.random.randint(4) * np.pi / 2.0 # (0,1,2,3)*90/180.0
63
        rot_z = np.random.randint(4) * np.pi / 2.0 # (0,1,2,3)*90/180.0
64
        rotated = applyRotation(flipped, [rot_x, rot_y, rot_z], spline_orders)
65
        new_imgs.append(convertListToBatch(rotated))
66
    return new_imgs
67
68
def augmentBoth(imgs):
69
    imgs = augmentPatchLossy(imgs)
70
    imgs = augmentPatchLessLess(imgs)
71
    return imgs
72
73
def getRotationVal(rotation=[5,5,5]):
74
    rot_x = np.random.uniform(-rotation[0], rotation[0]) * np.pi / 180.0
75
    rot_y = np.random.uniform(-rotation[1], rotation[1]) * np.pi / 180.0
76
    rot_z = np.random.uniform(-rotation[2], rotation[2]) * np.pi / 180.0
77
    return rot_x, rot_y, rot_z
78
79
def getScalingVal(scale_min = 0.9, scale_max = 1.1):
80
    return np.random.uniform(scale_min, scale_max)
81
82
def applyFLIPS(images, flip_lvl = 0):
83
    if flip_lvl == 0:
84
        p = np.random.rand(2) > 0.5
85
    else:
86
        p = np.random.rand(3) > 0.5
87
    locations = np.where(p == 1)[0] + 2
88
89
    new_imgs = []
90
    for img in images:
91
        for i in locations:
92
            img = np.flip(img, axis=i)
93
        new_imgs.append(img)
94
    return new_imgs
95
96
def applyFLIPS2(images, locations):
97
    new_imgs = []
98
    for img in images:
99
        for i in locations:
100
            img = np.flip(img, axis=i)
101
        new_imgs.append(img)
102
    return new_imgs
103
104
def applyRotation(images, rot, spline_orders):
105
    transform_x = np.array([[1.0,               0.0,            0.0],
106
                            [0.0,               np.cos(rot[0]), -np.sin(rot[0])],
107
                            [0.0,               np.sin(rot[0]), np.cos(rot[0])]])
108
109
    transform_y = np.array([[np.cos(rot[1]),    0.0,            np.sin(rot[1])],
110
                            [0.0,               1.0,            0.0],
111
                            [-np.sin(rot[1]),   0.0,            np.cos(rot[1])]])
112
113
    transform_z = np.array([[np.cos(rot[2]),    -np.sin(rot[2]),    0.0],
114
                            [np.sin(rot[2]),    np.cos(rot[2]),     0.0],
115
                            [0.0,               0,                  1]])
116
    transform = np.dot(transform_z, np.dot(transform_x, transform_y))
117
118
    new_imgs = []
119
    for i, img in enumerate(images):
120
        mid_index = 0.5 * np.asarray(img.squeeze().shape, dtype=np.int64)
121
        offset = mid_index - mid_index.dot(np.linalg.inv(transform))
122
        new_img = scipy.ndimage.affine_transform(
123
                                            input = img.squeeze(), 
124
                                            matrix = transform, 
125
                                            offset = offset, 
126
                                            order = spline_orders[i],
127
                                            mode = 'nearest')
128
        new_img = new_img[np.newaxis,np.newaxis,:]
129
        new_imgs.append(new_img)
130
    return new_imgs
131
132
def applyScale(images, zoom_val, spline_orders):
133
    new_imgs = []
134
    for i, img in enumerate(images):
135
        with warnings.catch_warnings():
136
            warnings.simplefilter("ignore")
137
            try:
138
                new_img = scipy.ndimage.zoom(img.squeeze(), zoom_val, order = spline_orders[i])
139
                new_img = new_img[np.newaxis,np.newaxis,:]
140
                new_imgs.append(new_img)
141
            except:
142
                pass
143
    return new_imgs