Diff of /T1-TSE/Augmentation.py [000000] .. [6a4082]

Switch to unified view

a b/T1-TSE/Augmentation.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
import random
21
import numpy as np
22
import utils
23
import math
24
import h5py
25
from scipy import ndimage as nd
26
27
28
class Random_Rotation:
29
30
31
    def __init__(self,image):
32
        self.image = image
33
34
    def RandomRotation(self,output_shape):
35
        '''
36
        The function generated a rotation matrix randomly such that the rotation axis is uniformly distributed
37
        on a unit sphere and the angle is uniformally distributed.
38
        :param output_shape: shape of the output image
39
        :return:Randomly rotated 3D image
40
        '''
41
42
43
        rotation_matrix =utils.generating_random_rotation_matrix()
44
        Output_image =self.__rotation__(output_shape, rotation_matrix)
45
        return Output_image
46
47
    def __rotation__(self,Output_shape,rotation_matrix):
48
        '''
49
        :param Output_shape: shape of the rotated image
50
        :param rotation_matrix:
51
        :return: Rotated 3D image based on rotation matrix
52
        '''
53
54
        image = self.image
55
        h,w,d = image.shape
56
        coordinate_j, coordinate_i, coordinate_k = np.meshgrid(
57
            np.array(range(Output_shape[1])), np.array(range(Output_shape[0])),
58
            np.array(range(Output_shape[2])))
59
60
        center_target = np.array([int(sh/2) for sh in list(Output_shape)]).reshape(3,1)
61
        center_source = np.array([int(h/2),int(w/2),int(d/2)]).reshape(3,1)
62
        coordinate_init = np.array([coordinate_j.flatten(), coordinate_i.flatten(), coordinate_k.flatten()])
63
        Rotation_matrix = rotation_matrix
64
        coordinate = coordinate_init - np.matlib.repmat(center_target,1,coordinate_init.shape[1]) + np.matlib.repmat(np.matmul(Rotation_matrix,center_source),1,coordinate_init.shape[1])
65
66
67
        mapped_to_source_coordinate = np.linalg.solve(Rotation_matrix,coordinate)
68
        output_coordinate_value = nd.map_coordinates(input=image,coordinates=mapped_to_source_coordinate,cval = -1000,order = 4,mode = 'constant')
69
70
        Output_image = -1000*np.ones(shape=Output_shape,dtype=float)
71
72
        for k in range(coordinate_init.shape[1]):
73
            Output_image[coordinate_init[0,k],coordinate_init[1,k],coordinate_init[2,k]] = output_coordinate_value[k]
74
75
        return Output_image
76
77
    def __random_rotation_matrix_fixing_rotation_axis__(self,axis = 0):
78
79
        '''
80
        :param axis: axis to fix
81
        :return: randomly rotated 3D image such that the axis of rotation is fixed
82
        '''
83
84
        axis_of_rotation = [0.0, 0.0, 0.0]
85
        axis_of_rotation[axis] = 1.0
86
        angle_of_rotation = 2 * math.pi * random.uniform(0, 1)
87
        Rotation_matrix = utils.angle_axis_to_rotation_matrix(angle=angle_of_rotation,axis=axis_of_rotation)
88
        return Rotation_matrix
89
90
    def RandomRotation_x_axis(self,Output_shape):
91
92
        '''
93
        :param Output_shape: shape of the output image
94
        :return: rotated 3D image along the x axis
95
        '''
96
97
        h, w, d = self.image.shape
98
        rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis = 0 )
99
        Output_image = self.__rotation__(Output_shape, rotation_matrix)
100
        return Output_image
101
102
    def RandomRotation_y_axis(self,Output_shape):
103
        '''
104
        :param Output_shape: shape of the output image
105
        :return: rotated 3D image along the y axis
106
        '''
107
108
        h, w, d = self.image.shape
109
        rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis=1)
110
        Output_image = self.__rotation__(Output_shape, rotation_matrix)
111
        return Output_image
112
113
    def RandomRotation_z_axis(self,Output_shape):
114
        '''
115
        :param Output_shape: shape of the output image
116
        :return: rotated 3D image along the z axis
117
        '''
118
        h, w, d = self.image.shape
119
        rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis=2)
120
        print("rotation_matrix: ",rotation_matrix)
121
        Output_image = self.__rotation__(Output_shape, rotation_matrix)
122
        return Output_image
123
124
125
class RandomCrop:
126
127
    '''
128
    Randomly Crop 3D image
129
    '''
130
131
    def __init__(self,image):
132
        self.image = image
133
        self.h,self.w,self.d = image.shape
134
135
    def __functional__(self,size):
136
        '''
137
        :param size: crop size
138
        :return: cropped 3D image
139
        '''
140
141
        h, w, d = self.image.shape
142
        crop_h, crop_w, crop_d = size
143
        i = random.randint(0, h - crop_h)
144
        j = random.randint(0, w - crop_w)
145
        k = random.randint(0, d - crop_d)
146
        crop_image = self.image[i:i + crop_h, j:j + crop_w, k:k + crop_d]
147
        return crop_image
148
149
    def crop_along_hieght_width(self,crop_size):
150
151
        crop = (crop_size[0],crop_size[2],self.d)
152
        self.crop_image = self.__functional__(size=crop)
153
        return self.crop_image
154
155
    def crop_along_hieght_width_depth(self,crop_size):
156
        self.crop_image = self.__functional__(size=crop_size)
157
        return self.crop_image
158
159
160
class CenterCrop:
161
162
    '''
163
    CenterCrop Images
164
    '''
165
166
    def __init__(self,image):
167
        self.image = image
168
        self.h,self.w,self.d = image.shape
169
170
    def __functional__(self,size):
171
        '''
172
        :param size: crop Size
173
        :return: Center Crop Images
174
        '''
175
        crop_h = int((self.h-size[0])/2)
176
        crop_w = int((self.w-size[1])/2)
177
        crop_d = int((self.d-size[2])/2)
178
        return self.image[crop_h:crop_h+size[0],crop_w:crop_w+size[1],crop_d:crop_d+size[2]]
179
180
    def crop(self,size):
181
        return self.__functional__(size)
182
183
184
class RandomFlip:
185
186
    def __init__(self,image,p=0.5):
187
        self.image = image
188
        self.p = p
189
        self.h,self.w,self.d = image.shape
190
191
    def horizontal_flip(self,p=-1):
192
        '''
193
        :param p: probability of flip
194
        :return: randomly horizontaly flipped image
195
        '''
196
        if p == -1:
197
            p = self.p
198
199
        integer = random.randint(0, 1)
200
        if integer <= p:
201
            output_image = self.image[:, -1:0:-1, :]
202
        else:
203
            output_image = self.image
204
205
        return output_image
206
207
    def vertical_flip(self,p=-1):
208
209
        '''
210
        :param p: probability of flip
211
        :return: randomly vertically flipped image
212
        '''
213
214
        if p == -1:
215
            p = self.p
216
217
        integer = random.randint(0, 1)
218
        if integer <= p:
219
            output_image = np.flipud(self.image)
220
        else:
221
            output_image = self.image
222
223
        return output_image
224
    
225
    def horizontal_flip(self,p=-1):
226
227
        '''
228
        :param p: probability of flip
229
        :return: randomly vertically flipped image
230
        '''
231
232
        if p == -1:
233
            p = self.p
234
235
        integer = random.randint(0, 1)
236
        if integer <= p:
237
            output_image = np.fliplr(self.image)
238
        else:
239
            output_image = self.image
240
241
        return output_image