--- a +++ b/IW-TSE/Augmentation.py @@ -0,0 +1,241 @@ +# ============================================================================== +# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, +# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz +# +# This file is part of OAI-MRI-TKR +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# ============================================================================== +import random +import numpy as np +import utils +import math +import h5py +from scipy import ndimage as nd + + +class Random_Rotation: + + + def __init__(self,image): + self.image = image + + def RandomRotation(self,output_shape): + ''' + The function generated a rotation matrix randomly such that the rotation axis is uniformly distributed + on a unit sphere and the angle is uniformally distributed. + :param output_shape: shape of the output image + :return:Randomly rotated 3D image + ''' + + + rotation_matrix =utils.generating_random_rotation_matrix() + Output_image =self.__rotation__(output_shape, rotation_matrix) + return Output_image + + def __rotation__(self,Output_shape,rotation_matrix): + ''' + :param Output_shape: shape of the rotated image + :param rotation_matrix: + :return: Rotated 3D image based on rotation matrix + ''' + + image = self.image + h,w,d = image.shape + coordinate_j, coordinate_i, coordinate_k = np.meshgrid( + np.array(range(Output_shape[1])), np.array(range(Output_shape[0])), + np.array(range(Output_shape[2]))) + + center_target = np.array([int(sh/2) for sh in list(Output_shape)]).reshape(3,1) + center_source = np.array([int(h/2),int(w/2),int(d/2)]).reshape(3,1) + coordinate_init = np.array([coordinate_j.flatten(), coordinate_i.flatten(), coordinate_k.flatten()]) + Rotation_matrix = rotation_matrix + coordinate = coordinate_init - np.matlib.repmat(center_target,1,coordinate_init.shape[1]) + np.matlib.repmat(np.matmul(Rotation_matrix,center_source),1,coordinate_init.shape[1]) + + + mapped_to_source_coordinate = np.linalg.solve(Rotation_matrix,coordinate) + output_coordinate_value = nd.map_coordinates(input=image,coordinates=mapped_to_source_coordinate,cval = -1000,order = 4,mode = 'constant') + + Output_image = -1000*np.ones(shape=Output_shape,dtype=float) + + for k in range(coordinate_init.shape[1]): + Output_image[coordinate_init[0,k],coordinate_init[1,k],coordinate_init[2,k]] = output_coordinate_value[k] + + return Output_image + + def __random_rotation_matrix_fixing_rotation_axis__(self,axis = 0): + + ''' + :param axis: axis to fix + :return: randomly rotated 3D image such that the axis of rotation is fixed + ''' + + axis_of_rotation = [0.0, 0.0, 0.0] + axis_of_rotation[axis] = 1.0 + angle_of_rotation = 2 * math.pi * random.uniform(0, 1) + Rotation_matrix = utils.angle_axis_to_rotation_matrix(angle=angle_of_rotation,axis=axis_of_rotation) + return Rotation_matrix + + def RandomRotation_x_axis(self,Output_shape): + + ''' + :param Output_shape: shape of the output image + :return: rotated 3D image along the x axis + ''' + + h, w, d = self.image.shape + rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis = 0 ) + Output_image = self.__rotation__(Output_shape, rotation_matrix) + return Output_image + + def RandomRotation_y_axis(self,Output_shape): + ''' + :param Output_shape: shape of the output image + :return: rotated 3D image along the y axis + ''' + + h, w, d = self.image.shape + rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis=1) + Output_image = self.__rotation__(Output_shape, rotation_matrix) + return Output_image + + def RandomRotation_z_axis(self,Output_shape): + ''' + :param Output_shape: shape of the output image + :return: rotated 3D image along the z axis + ''' + h, w, d = self.image.shape + rotation_matrix = self.__random_rotation_matrix_fixing_rotation_axis__(axis=2) + print("rotation_matrix: ",rotation_matrix) + Output_image = self.__rotation__(Output_shape, rotation_matrix) + return Output_image + + +class RandomCrop: + + ''' + Randomly Crop 3D image + ''' + + def __init__(self,image): + self.image = image + self.h,self.w,self.d = image.shape + + def __functional__(self,size): + ''' + :param size: crop size + :return: cropped 3D image + ''' + + h, w, d = self.image.shape + crop_h, crop_w, crop_d = size + i = random.randint(0, h - crop_h) + j = random.randint(0, w - crop_w) + k = random.randint(0, d - crop_d) + crop_image = self.image[i:i + crop_h, j:j + crop_w, k:k + crop_d] + return crop_image + + def crop_along_hieght_width(self,crop_size): + + crop = (crop_size[0],crop_size[2],self.d) + self.crop_image = self.__functional__(size=crop) + return self.crop_image + + def crop_along_hieght_width_depth(self,crop_size): + self.crop_image = self.__functional__(size=crop_size) + return self.crop_image + + +class CenterCrop: + + ''' + CenterCrop Images + ''' + + def __init__(self,image): + self.image = image + self.h,self.w,self.d = image.shape + + def __functional__(self,size): + ''' + :param size: crop Size + :return: Center Crop Images + ''' + crop_h = int((self.h-size[0])/2) + crop_w = int((self.w-size[1])/2) + crop_d = int((self.d-size[2])/2) + return self.image[crop_h:crop_h+size[0],crop_w:crop_w+size[1],crop_d:crop_d+size[2]] + + def crop(self,size): + return self.__functional__(size) + + +class RandomFlip: + + def __init__(self,image,p=0.5): + self.image = image + self.p = p + self.h,self.w,self.d = image.shape + + def horizontal_flip(self,p=-1): + ''' + :param p: probability of flip + :return: randomly horizontaly flipped image + ''' + if p == -1: + p = self.p + + integer = random.randint(0, 1) + if integer <= p: + output_image = self.image[:, -1:0:-1, :] + else: + output_image = self.image + + return output_image + + def vertical_flip(self,p=-1): + + ''' + :param p: probability of flip + :return: randomly vertically flipped image + ''' + + if p == -1: + p = self.p + + integer = random.randint(0, 1) + if integer <= p: + output_image = np.flipud(self.image) + else: + output_image = self.image + + return output_image + + def horizontal_flip(self,p=-1): + + ''' + :param p: probability of flip + :return: randomly vertically flipped image + ''' + + if p == -1: + p = self.p + + integer = random.randint(0, 1) + if integer <= p: + output_image = np.fliplr(self.image) + else: + output_image = self.image + + return output_image