Diff of /DESS/Augmentation.py [000000] .. [6a4082]

Switch to side-by-side view

--- a
+++ b/DESS/Augmentation.py
@@ -0,0 +1,241 @@
+# ==============================================================================
+# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
+# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
+#
+# This file is part of OAI-MRI-TKR
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+# ==============================================================================
+import random
+import numpy as np
+import utils
+import math
+import h5py
+from scipy import ndimage as nd
+
+
+class Random_Rotation:
+
+
+    def __init__(self,image):
+        self.image = image
+
+    def RandomRotation(self,output_shape):
+        '''
+        The function generated a rotation matrix randomly such that the rotation axis is uniformly distributed
+        on a unit sphere and the angle is uniformally distributed.
+        :param output_shape: shape of the output image
+        :return:Randomly rotated 3D image
+        '''
+
+
+        rotation_matrix =utils.generating_random_rotation_matrix()
+        Output_image =self.__rotation__(output_shape, rotation_matrix)
+        return Output_image
+
+    def __rotation__(self,Output_shape,rotation_matrix):
+        '''
+        :param Output_shape: shape of the rotated image
+        :param rotation_matrix:
+        :return: Rotated 3D image based on rotation matrix
+        '''
+
+        image = self.image
+        h,w,d = image.shape
+        coordinate_j, coordinate_i, coordinate_k = np.meshgrid(
+            np.array(range(Output_shape[1])), np.array(range(Output_shape[0])),
+            np.array(range(Output_shape[2])))
+
+        center_target = np.array([int(sh/2) for sh in list(Output_shape)]).reshape(3,1)
+        center_source = np.array([int(h/2),int(w/2),int(d/2)]).reshape(3,1)
+        coordinate_init = np.array([coordinate_j.flatten(), coordinate_i.flatten(), coordinate_k.flatten()])
+        Rotation_matrix = rotation_matrix
+        coordinate = coordinate_init - np.matlib.repmat(center_target,1,coordinate_init.shape[1]) + np.matlib.repmat(np.matmul(Rotation_matrix,center_source),1,coordinate_init.shape[1])
+
+
+        mapped_to_source_coordinate = np.linalg.solve(Rotation_matrix,coordinate)
+        output_coordinate_value = nd.map_coordinates(input=image,coordinates=mapped_to_source_coordinate,cval = -1000,order = 4,mode = 'constant')
+
+        Output_image = -1000*np.ones(shape=Output_shape,dtype=float)
+
+        for k in range(coordinate_init.shape[1]):
+            Output_image[coordinate_init[0,k],coordinate_init[1,k],coordinate_init[2,k]] = output_coordinate_value[k]
+
+        return Output_image
+
+    def __random_rotation_matrix_fixing_rotation_axis__(self,axis = 0):
+
+        '''
+        :param axis: axis to fix
+        :return: randomly rotated 3D image such that the axis of rotation is fixed
+        '''
+
+        axis_of_rotation = [0.0, 0.0, 0.0]
+        axis_of_rotation[axis] = 1.0
+        angle_of_rotation = 2 * math.pi * random.uniform(0, 1)
+        Rotation_matrix = utils.angle_axis_to_rotation_matrix(angle=angle_of_rotation,axis=axis_of_rotation)
+        return Rotation_matrix
+
+    def RandomRotation_x_axis(self,Output_shape):
+
+        '''
+        :param Output_shape: shape of the output image
+        :return: rotated 3D image along the x axis
+        '''
+
+        h, w, d = self.image.shape
+        rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis = 0 )
+        Output_image = self.__rotation__(Output_shape, rotation_matrix)
+        return Output_image
+
+    def RandomRotation_y_axis(self,Output_shape):
+        '''
+        :param Output_shape: shape of the output image
+        :return: rotated 3D image along the y axis
+        '''
+
+        h, w, d = self.image.shape
+        rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis=1)
+        Output_image = self.__rotation__(Output_shape, rotation_matrix)
+        return Output_image
+
+    def RandomRotation_z_axis(self,Output_shape):
+        '''
+        :param Output_shape: shape of the output image
+        :return: rotated 3D image along the z axis
+        '''
+        h, w, d = self.image.shape
+        rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis=2)
+        print("rotation_matrix: ",rotation_matrix)
+        Output_image = self.__rotation__(Output_shape, rotation_matrix)
+        return Output_image
+
+
+class RandomCrop:
+
+    '''
+    Randomly Crop 3D image
+    '''
+
+    def __init__(self,image):
+        self.image = image
+        self.h,self.w,self.d = image.shape
+
+    def __functional__(self,size):
+        '''
+        :param size: crop size
+        :return: cropped 3D image
+        '''
+
+        h, w, d = self.image.shape
+        crop_h, crop_w, crop_d = size
+        i = random.randint(0, h - crop_h)
+        j = random.randint(0, w - crop_w)
+        k = random.randint(0, d - crop_d)
+        crop_image = self.image[i:i + crop_h, j:j + crop_w, k:k + crop_d]
+        return crop_image
+
+    def crop_along_hieght_width(self,crop_size):
+
+        crop = (crop_size[0],crop_size[2],self.d)
+        self.crop_image = self.__functional__(size=crop)
+        return self.crop_image
+
+    def crop_along_hieght_width_depth(self,crop_size):
+        self.crop_image = self.__functional__(size=crop_size)
+        return self.crop_image
+
+
+class CenterCrop:
+
+    '''
+    CenterCrop Images
+    '''
+
+    def __init__(self,image):
+        self.image = image
+        self.h,self.w,self.d = image.shape
+
+    def __functional__(self,size):
+        '''
+        :param size: crop Size
+        :return: Center Crop Images
+        '''
+        crop_h = int((self.h-size[0])/2)
+        crop_w = int((self.w-size[1])/2)
+        crop_d = int((self.d-size[2])/2)
+        return self.image[crop_h:crop_h+size[0],crop_w:crop_w+size[1],crop_d:crop_d+size[2]]
+
+    def crop(self,size):
+        return self.__functional__(size)
+
+
+class RandomFlip:
+
+    def __init__(self,image,p=0.5):
+        self.image = image
+        self.p = p
+        self.h,self.w,self.d = image.shape
+
+    def horizontal_flip(self,p=-1):
+        '''
+        :param p: probability of flip
+        :return: randomly horizontaly flipped image
+        '''
+        if p == -1:
+            p = self.p
+
+        integer = random.randint(0, 1)
+        if integer <= p:
+            output_image = self.image[:, -1:0:-1, :]
+        else:
+            output_image = self.image
+
+        return output_image
+
+    def vertical_flip(self,p=-1):
+
+        '''
+        :param p: probability of flip
+        :return: randomly vertically flipped image
+        '''
+
+        if p == -1:
+            p = self.p
+
+        integer = random.randint(0, 1)
+        if integer <= p:
+            output_image = np.flipud(self.image)
+        else:
+            output_image = self.image
+
+        return output_image
+    
+    def horizontal_flip(self,p=-1):
+
+        '''
+        :param p: probability of flip
+        :return: randomly vertically flipped image
+        '''
+
+        if p == -1:
+            p = self.p
+
+        integer = random.randint(0, 1)
+        if integer <= p:
+            output_image = np.fliplr(self.image)
+        else:
+            output_image = self.image
+
+        return output_image