Diff of /utils.py [000000] .. [72db80]

Switch to side-by-side view

--- /dev/null
+++ b/utils.py
@@ -0,0 +1,185 @@
+import os
+import json
+import torch
+import glob
+
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+
+from imgaug import augmenters as iaa
+from imgaug.augmentables.segmaps import SegmentationMapsOnImage
+
+import numpy as np
+
def create_nested_dir(log_path):
    """Create the experiment directory tree, including its 'checkpoint' subdir.

    Uses ``exist_ok=True`` so a partially created tree (e.g. ``log_path``
    exists but ``checkpoint`` does not) is completed instead of skipped —
    the original ``isdir`` guard silently skipped the checkpoint directory
    in that case, and was also racy between the check and the creation.

    Args:
        log_path: Root directory of the experiment.
    """
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(os.path.join(log_path, 'checkpoint'), exist_ok=True)
+
def load_dataset_dist():
    """Load the train/valid case split from configuration/cases_division.json.

    Returns:
        The parsed JSON object describing the dataset division
        (relative to the current working directory).
    """
    config_path = os.path.join('configuration', 'cases_division.json')
    with open(config_path, 'r') as config_file:
        return json.load(config_file)
+
def get_data_loaders(data_aug, cases, dataset_dir, batch_size):
    """Build the DataLoaders for training and validation.

    Augmentation (``data_aug``) is applied to the training split only;
    the validation split always runs with ``'none'``.

    Args:
        data_aug: Augmentation mode forwarded to the training dataset.
        cases: Mapping with 'train' and 'valid' case lists.
        dataset_dir: Root directory of the dataset.
        batch_size: Batch size for both loaders.

    Returns:
        Dict with keys 'Train' and 'Valid' mapping to DataLoaders.
    """
    train_loader = get_dataset(
        dataset_dir, data_aug, cases=cases['train'],
        balanced_filelist=None, batch_size=batch_size)
    valid_loader = get_dataset(
        dataset_dir, 'none', cases=cases['valid'], batch_size=batch_size)

    return {'Train': train_loader, 'Valid': valid_loader}
+
def get_dataset(data_dir, data_aug, cases=(), balanced_filelist=None, imageFolder='Images', maskFolder='Masks', batch_size=4):
    """Build a shuffling DataLoader over a SegNumpyDataset.

    Fixes the mutable default argument (``cases=[]`` → ``cases=()``): a
    shared list default can be mutated across calls; an empty tuple is
    behaviorally equivalent for iteration. Also drops the dead 'Test'
    entry of the transform dict — only 'Train' was ever used.

    Args:
        data_dir: Root directory of the dataset.
        data_aug: Augmentation mode forwarded to the dataset
            ('online', 'offline', or anything else for none).
        cases: Iterable of case identifiers to include.
        balanced_filelist: Optional whitelist of image file names.
        imageFolder: Sub-directory name holding images.
        maskFolder: Sub-directory name holding masks.
        batch_size: Batch size of the returned loader.

    Returns:
        A shuffling DataLoader (drop_last=True, 4 workers).
    """
    # ToTensor is the only transform; augmentation itself is handled
    # inside SegNumpyDataset according to `data_aug`.
    transform = transforms.Compose([ToTensor()])

    image_dataset = SegNumpyDataset(
        data_aug=data_aug, root_dir=data_dir, cases=cases,
        transform=transform, maskFolder=maskFolder,
        imageFolder=imageFolder, balanced_filelist=balanced_filelist)

    # drop_last keeps every batch the same size; shuffle for training.
    return DataLoader(
        image_dataset, batch_size=batch_size, drop_last=True,
        shuffle=True, num_workers=4)
+
class ToTensor(object):
    """Convert the numpy arrays of a sample into float torch Tensors.

    2-D arrays gain a leading singleton channel axis (H, W) -> (1, H, W);
    higher-rank arrays pass through unchanged.
    """

    def __call__(self, sample, maskresize=None, imageresize=None):
        def _with_channel(arr):
            # Promote HxW to 1xHxW; leave CxHxW (or higher rank) untouched.
            if arr.ndim == 2:
                return arr.reshape((1,) + arr.shape)
            return arr

        image = _with_channel(sample['image'])
        mask = _with_channel(sample['mask'])
        return {'image': torch.from_numpy(image).float(),
                'mask': torch.from_numpy(mask).float()}
+
class SegNumpyDataset(Dataset):
    """Segmentation dataset backed by per-case folders of .npy/.npz arrays.

    Expected layout: ``<root_dir>/<case>/<imageFolder>/<name>`` holds an
    image array and ``<root_dir>/<case>/<maskFolder>/masc_<name>`` holds
    the matching mask (same file name with a 'masc_' prefix).

    Args:
        root_dir: Root directory containing one sub-directory per case.
        cases: Iterable of case identifiers (names, or ints when
            cases_number_format is True).
        imageFolder: Sub-directory name holding the image arrays.
        maskFolder: Sub-directory name holding the mask arrays.
        data_aug: Augmentation mode: 'online' (imgaug applied at load
            time), 'offline' (pre-generated '-aug-N' files on disk), or
            anything else for no augmentation.
        cases_number_format: When True, format integer case ids as
            'case_00000'-style directory names.
        transform: Optional callable applied to the {'image', 'mask'}
            sample dict (e.g. ToTensor).
        balanced_filelist: Optional whitelist of image file names; when
            given, only files in this list are used.
    """

    def __init__(self, root_dir, cases, imageFolder, maskFolder, data_aug, cases_number_format=False, transform=None, balanced_filelist=None):
        # NOTE(review): hard-coded to 3 channels, which makes the
        # single-channel branch in __getitem__ dead code at present.
        self.in_channels  = 3
        self.root_dir = root_dir
        self.transform = transform
        self.data_aug = data_aug

        if cases_number_format:
            cases_names = ["case_{:05d}".format(i) for i in cases]
        else:
            cases_names = cases

        image_names = []
        mask_names = []

        if balanced_filelist is None:
            # No whitelist: take every file of every requested case.
            for case in cases_names:
                image_names.extend(glob.glob(os.path.join(
                    self.root_dir, case, imageFolder, '*')))
                mask_names.extend(glob.glob(os.path.join(
                    self.root_dir, case, maskFolder, '*')))
        else:
            # This condition is needed because offline data augmentation
            # changes the file names on disk.
            if data_aug != 'offline':
                for case in cases_names:
                    image_list = set(os.listdir(os.path.join(
                        self.root_dir, case, imageFolder)))
                    set_balanced = set(balanced_filelist)

                    # Keep only whitelisted files that exist for this case.
                    image_list = list(set_balanced.intersection(image_list))
                    fullpath_image_list = [os.path.join(self.root_dir, case, imageFolder, x)
                                           for x in image_list]
                    # Masks share the image file name with a 'masc_' prefix.
                    fullpath_mask_list = [os.path.join(self.root_dir, case, maskFolder, "masc_"+str(x))
                                          for x in image_list]

                    image_names.extend(fullpath_image_list)
                    mask_names.extend(fullpath_mask_list)
            else:
                for case in cases_names:
                    image_list = set(os.listdir(os.path.join(
                        self.root_dir, case, imageFolder)))

                    balanced_filelist_aug = []
                    # Add the augmented file names manually: each original
                    # entry expands to five '-aug-N' variants.

                    for fl in balanced_filelist:
                        for i in range(0, 5):
                            # e.g. case_00000-0-aug-0
                            balanced_filelist_aug.append(
                                "{}-aug-{}.npz".format(fl.replace(".npz", ""), i))

                    set_balanced = set(balanced_filelist_aug)

                    image_list = list(set_balanced.intersection(image_list))
                    fullpath_image_list = [os.path.join(self.root_dir, case, imageFolder, x)
                                           for x in image_list]
                    fullpath_mask_list = [os.path.join(self.root_dir, case, maskFolder, "masc_"+str(x))
                                          for x in image_list]

                    image_names.extend(fullpath_image_list)
                    mask_names.extend(fullpath_mask_list)

        # Sorting each list independently keeps image/mask pairs aligned
        # because the constant 'masc_' prefix preserves relative order.
        self.image_names = sorted(image_names)
        self.mask_names = sorted(mask_names)

    def __len__(self):
        # One sample per image file.
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load, optionally augment, and transform the idx-th sample."""

        image = np.load(self.image_names[idx])
        mask = np.load(self.mask_names[idx])

        __, file_extension = os.path.splitext(self.image_names[idx])

        # .npz archives wrap the array under numpy's default key 'arr_0'.
        # NOTE(review): the NpzFile handles are never closed explicitly.
        if file_extension == '.npz':
            image = image['arr_0']
            mask = mask['arr_0']


        if self.in_channels == 1:
            # Dead branch while in_channels is fixed at 3 (see __init__);
            # would select index 1 along the first axis of the image.
            image = image[1]
        
        if self.data_aug == 'online':

            # assumes masks are 256x256 — TODO confirm against the data.
            segmap = SegmentationMapsOnImage(mask, shape=(256, 256))

            # NOTE(review): this pipeline is rebuilt on every __getitem__
            # call; it could be constructed once in __init__.
            seq = iaa.Sequential([
                
                iaa.Affine(
                    scale=(0.5, 1.2),
                    rotate=(-15, 15)
                ),  # rotate the image
                iaa.Flipud(0.5),
                iaa.PiecewiseAffine(scale=(0.01, 0.05)),
                iaa.Sometimes(
                    0.1,
                    iaa.GaussianBlur((0.1, 1.5)),
                ),
                iaa.Sometimes(
                    0.1,
                    iaa.LinearContrast((0.5, 2.0), per_channel=0.5),
                )
            ])

            # imgaug expects channels-last (H, W, C).
            image = image.transpose(1, 2, 0)
            # Apply augmentations for image and mask
            image, mask = seq(image=image, segmentation_maps=segmap)
            image = image.copy()
            mask = mask.copy()
            # Back to channels-first (C, H, W) for PyTorch.
            image = image.transpose(2, 0, 1)
            mask = mask.get_arr()

        sample = {'image': image, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)

        return sample