--- a +++ b/utils/augmentations.py @@ -0,0 +1,143 @@ +import numpy as np +import nibabel as nib +import scipy.ndimage +import warnings +import PP +import sys + +#--------------------------------------------- +#Functions for image augmentations on 3D input +#--------------------------------------------- + +#img_b, label_b is (batch_num) x 1 x dim1 x dim2 x dim3 +#takes in a list of 3D images (1st one is input, 2nd one needs to be label) +def augmentPatchLossy(imgs, rotation=[5,5,5], scale_min=0.9, scale_max=1.1, flip_lvl = 0): + new_imgs = [] + + rot_x = np.random.uniform(-rotation[0], rotation[0]) * np.pi / 180.0 + rot_y = np.random.uniform(-rotation[1], rotation[1]) * np.pi / 180.0 + rot_z = np.random.uniform(-rotation[2], rotation[2]) * np.pi / 180.0 + + zoom_val = np.random.uniform(scale_min, scale_max) + for i in range(len(imgs)): + l = convertBatchToList(imgs[i]) + if i == 0: + spline_orders = [3] * len(l) + else: + spline_orders = [0] * len(l) + scaled = applyScale(l, zoom_val, spline_orders) + rotated = applyRotation(scaled, [rot_x, rot_y, rot_z], spline_orders) + new_imgs.append(convertListToBatch(rotated)) + return imgs + +def convertBatchToList(img): + l = [] + b, c, d1, d2, d3 = img.shape + for i in range(img.shape[0]): + l.append(img[i,:,:,:,:].reshape([1,c,d1,d2,d3])) + return l + +def convertListToBatch(img_list): + b, c, d1, d2, d3 = img_list[0].shape + a = np.zeros([len(img_list), c, d1,d2,d3]) + for i in range(len(img_list)): + a[i,:,:,:,:] = img_list[i] + return a + +def augmentPatchLossLess(imgs): + new_imgs = [] + + p = np.random.rand(3) > 0.5 + locations = np.where(p == 1)[0] + 2 + + for i in range(len(imgs)): + l = convertBatchToList(imgs[i]) + if i == 0: + spline_orders = [3] * len(l) + else: + spline_orders = [0] * len(l) + flipped = applyFLIPS2(l, locations) + + rot_x = np.random.randint(4) * np.pi / 2.0 # (0,1,2,3)*90/180.0 + rot_y = np.random.randint(4) * np.pi / 2.0 # (0,1,2,3)*90/180.0 + rot_z = np.random.randint(4) * np.pi / 2.0 # (0,1,2,3)*90/180.0 + rotated = applyRotation(flipped, [rot_x, rot_y, rot_z], spline_orders) + new_imgs.append(convertListToBatch(rotated)) + return new_imgs + +def augmentBoth(imgs): + imgs = augmentPatchLossy(imgs) + imgs = augmentPatchLessLess(imgs) + return imgs + +def getRotationVal(rotation=[5,5,5]): + rot_x = np.random.uniform(-rotation[0], rotation[0]) * np.pi / 180.0 + rot_y = np.random.uniform(-rotation[1], rotation[1]) * np.pi / 180.0 + rot_z = np.random.uniform(-rotation[2], rotation[2]) * np.pi / 180.0 + return rot_x, rot_y, rot_z + +def getScalingVal(scale_min = 0.9, scale_max = 1.1): + return np.random.uniform(scale_min, scale_max) + +def applyFLIPS(images, flip_lvl = 0): + if flip_lvl == 0: + p = np.random.rand(2) > 0.5 + else: + p = np.random.rand(3) > 0.5 + locations = np.where(p == 1)[0] + 2 + + new_imgs = [] + for img in images: + for i in locations: + img = np.flip(img, axis=i) + new_imgs.append(img) + return new_imgs + +def applyFLIPS2(images, locations): + new_imgs = [] + for img in images: + for i in locations: + img = np.flip(img, axis=i) + new_imgs.append(img) + return new_imgs + +def applyRotation(images, rot, spline_orders): + transform_x = np.array([[1.0, 0.0, 0.0], + [0.0, np.cos(rot[0]), -np.sin(rot[0])], + [0.0, np.sin(rot[0]), np.cos(rot[0])]]) + + transform_y = np.array([[np.cos(rot[1]), 0.0, np.sin(rot[1])], + [0.0, 1.0, 0.0], + [-np.sin(rot[1]), 0.0, np.cos(rot[1])]]) + + transform_z = np.array([[np.cos(rot[2]), -np.sin(rot[2]), 0.0], + [np.sin(rot[2]), np.cos(rot[2]), 0.0], + [0.0, 0, 1]]) + transform = np.dot(transform_z, np.dot(transform_x, transform_y)) + + new_imgs = [] + for i, img in enumerate(images): + mid_index = 0.5 * np.asarray(img.squeeze().shape, dtype=np.int64) + offset = mid_index - mid_index.dot(np.linalg.inv(transform)) + new_img = scipy.ndimage.affine_transform( + input = img.squeeze(), + matrix = transform, + offset = offset, + order = spline_orders[i], + mode = 'nearest') + new_img = new_img[np.newaxis,np.newaxis,:] + new_imgs.append(new_img) + return new_imgs + +def applyScale(images, zoom_val, spline_orders): + new_imgs = [] + for i, img in enumerate(images): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + new_img = scipy.ndimage.zoom(img.squeeze(), zoom_val, order = spline_orders[i]) + new_img = new_img[np.newaxis,np.newaxis,:] + new_imgs.append(new_img) + except: + pass + return new_imgs \ No newline at end of file