Switch to unified view

a b/Preprocessing Medical Data Pipeline/data_augmentation.py
1
import numpy as np
2
import cv2
3
from random import randint
4
from tqdm import tqdm
5
6
7
def flatten_array(arr):
8
    return arr.reshape(-1, *arr.shape[2:])
9
10
def ImageDataAugmentation(img, mask, num_augmentations):
11
    augmented_images = []
12
    augmented_masks = []
13
14
    for i in range(num_augmentations):
15
        
16
        # Apply a random rotation to the images
17
        angle = randint(-15, 15)
18
        M = cv2.getRotationMatrix2D((img.shape[1] / 2, img.shape[0] / 2), angle, 1)
19
        rotated_img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
20
        rotated_mask = cv2.warpAffine(mask, M, (mask.shape[1], mask.shape[0]))
21
22
        # Apply a random horizontal flip to the images
23
        if randint(0, 1):
24
            flipped_img = cv2.flip(rotated_img, 1)
25
            flipped_mask = cv2.flip(rotated_mask, 1)
26
        else:
27
            flipped_img = rotated_img
28
            flipped_mask = rotated_mask
29
30
        # Append the augmented images and masks to the lists
31
        augmented_images.append(flipped_img)
32
        augmented_masks.append(flipped_mask)
33
34
    # Convert the lists of augmented images and masks to NumPy arrays
35
    augmented_images = np.array(augmented_images)
36
    augmented_masks = np.array(augmented_masks)
37
38
    return augmented_images, augmented_masks
39
40
def DataAugmentation (imgs_train, imgs_mask_train, num_augmentations):
41
    augmented_images_train = []
42
    augmented_masks_train = []
43
44
    for i in tqdm(range(len(imgs_train))):
45
        augmented_images, augmented_masks = ImageDataAugmentation(imgs_train[i,:,:,0], imgs_mask_train[i,:,:,0], num_augmentations)
46
        augmented_images_train.append(augmented_images)
47
        augmented_masks_train.append(augmented_masks)
48
49
    augmented_images_train = np.array(augmented_images_train)
50
    augmented_masks_train = np.array(augmented_masks_train)
51
52
    augmented_images_train = flatten_array(augmented_images_train)
53
    augmented_masks_train = flatten_array(augmented_masks_train)
54
55
    return augmented_images_train, augmented_masks_train