Diff of /utils.py [000000] .. [72db80]

Switch to unified view

a b/utils.py
1
import os
2
import json
3
import torch
4
import glob
5
6
from torch.utils.data import Dataset, DataLoader
7
from torchvision import transforms
8
9
from imgaug import augmenters as iaa
10
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
11
12
import numpy as np
13
14
def create_nested_dir(log_path):
    """Create the experiment directory and its ``checkpoint`` subdirectory.

    The call is idempotent: directories that already exist are left alone,
    and a missing ``checkpoint`` folder is created even when ``log_path``
    itself is already present.

    Args:
        log_path: Path of the experiment directory.

    Raises:
        OSError: If a path component exists but is not a directory, or the
            directories cannot be created.
    """
    # Creating the leaf also creates every missing parent, log_path included.
    os.makedirs(os.path.join(log_path, 'checkpoint'), exist_ok=True)
19
20
def load_dataset_dist():
    """Load the case division stored in ``configuration/cases_division.json``.

    The path is relative, so it is resolved against the current working
    directory.

    Returns:
        The object decoded from the JSON file.
    """
    config_path = os.path.join('configuration', 'cases_division.json')
    with open(config_path, 'r') as handle:
        return json.load(handle)
26
27
def get_data_loaders(data_aug, cases, dataset_dir, batch_size):
    """Build the training and validation loaders for one experiment.

    Args:
        data_aug: Augmentation mode forwarded to the training dataset.
        cases: Mapping with ``'train'`` and ``'valid'`` entries holding the
            cases of each split.
        dataset_dir: Root directory of the dataset.
        batch_size: Batch size used by both loaders.

    Returns:
        A dict with the training loader under ``'Train'`` and the validation
        loader under ``'Valid'``.
    """
    train_loader = get_dataset(
        dataset_dir, data_aug, cases=cases['train'],
        balanced_filelist=None, batch_size=batch_size)

    # The validation split is always built with the 'none' augmentation mode.
    valid_loader = get_dataset(
        dataset_dir, 'none', cases=cases['valid'], batch_size=batch_size)

    return {'Train': train_loader, 'Valid': valid_loader}
37
38
def get_dataset(data_dir, data_aug, cases=None, balanced_filelist=None, imageFolder='Images', maskFolder='Masks', batch_size=4, shuffle=True, drop_last=True, num_workers=4):
    """Create a ``DataLoader`` over a ``SegNumpyDataset``.

    Args:
        data_dir: Root directory of the dataset.
        data_aug: Augmentation mode forwarded to ``SegNumpyDataset``.
        cases: Cases to include; ``None`` (the default) means no cases.
        balanced_filelist: Optional list of file names used to restrict the
            samples; forwarded to ``SegNumpyDataset``.
        imageFolder: Name of the per-case folder holding the images.
        maskFolder: Name of the per-case folder holding the masks.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        drop_last: Whether the last incomplete batch is dropped.
        num_workers: Number of worker processes used by the loader.

    The defaults of ``shuffle``, ``drop_last`` and ``num_workers`` are the
    values that used to be hard-coded here, so existing callers are
    unaffected.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    # A None sentinel avoids sharing one mutable default list between calls.
    if cases is None:
        cases = []

    image_dataset = SegNumpyDataset(
        data_aug=data_aug, root_dir=data_dir, cases=cases,
        transform=transforms.Compose([ToTensor()]),
        maskFolder=maskFolder, imageFolder=imageFolder,
        balanced_filelist=balanced_filelist)

    return DataLoader(
        image_dataset, batch_size=batch_size, drop_last=drop_last,
        shuffle=shuffle, num_workers=num_workers)
52
53
class ToTensor(object):
    """Turn the numpy arrays of a sample into float tensors."""

    def __call__(self, sample, maskresize=None, imageresize=None):
        """Convert ``sample['image']`` and ``sample['mask']`` to tensors.

        A 2-D array gets a leading axis of length one before conversion;
        arrays of any other rank keep their shape.

        Args:
            sample: Mapping with numpy arrays under ``'image'`` and ``'mask'``.
            maskresize: Unused.
            imageresize: Unused.

        Returns:
            A dict with float tensors under ``'image'`` and ``'mask'``.
        """
        arrays = {'image': sample['image'], 'mask': sample['mask']}

        converted = {}
        for key, array in arrays.items():
            if array.ndim == 2:
                array = array.reshape((1,) + array.shape)
            converted[key] = torch.from_numpy(array).float()
        return converted
64
65
class SegNumpyDataset(Dataset):
    """Segmentation dataset backed by per-case folders of numpy files.

    Files are looked up under ``<root_dir>/<case>/<imageFolder>`` and
    ``<root_dir>/<case>/<maskFolder>``. Image and mask paths are sorted
    independently and paired by position, so both folders must yield
    matching files in matching sort order.
    """

    def __init__(self, root_dir, cases, imageFolder, maskFolder, data_aug, cases_number_format=False, transform=None, balanced_filelist=None):
        """Collect the image and mask paths of the requested cases.

        Args:
            root_dir: Root directory of the dataset.
            cases: Case folder names, or integers when
                ``cases_number_format`` is true.
            imageFolder: Name of the per-case folder holding the images.
            maskFolder: Name of the per-case folder holding the masks.
            data_aug: Augmentation mode. ``'online'`` augments every sample
                on access; ``'offline'`` changes how ``balanced_filelist``
                is matched against the files on disk.
            cases_number_format: If true, each entry of ``cases`` is
                formatted as ``case_{:05d}``.
            transform: Optional callable applied to each sample dict.
            balanced_filelist: Optional file names; when given, only image
                files whose name is in this list are kept, and each mask
                name is the image name prefixed with ``masc_``.
        """
        self.in_channels = 3
        self.root_dir = root_dir
        self.transform = transform
        self.data_aug = data_aug

        if cases_number_format:
            cases_names = ["case_{:05d}".format(i) for i in cases]
        else:
            cases_names = cases

        image_names = []
        mask_names = []

        if balanced_filelist is None:
            # No filter: take every file found in both folders.
            for case in cases_names:
                image_names.extend(glob.glob(os.path.join(
                    self.root_dir, case, imageFolder, '*')))
                mask_names.extend(glob.glob(os.path.join(
                    self.root_dir, case, maskFolder, '*')))
        else:
            # The allowed names do not depend on the case, so build the set
            # once instead of once per case.
            if data_aug != 'offline':
                allowed = set(balanced_filelist)
            else:
                # Offline augmentation renames the files on disk, so derive
                # the augmented names, e.g. case_00000-0-aug-0.npz.
                allowed = {
                    "{}-aug-{}.npz".format(fl.replace(".npz", ""), i)
                    for fl in balanced_filelist
                    for i in range(0, 5)
                }

            for case in cases_names:
                image_dir = os.path.join(self.root_dir, case, imageFolder)
                mask_dir = os.path.join(self.root_dir, case, maskFolder)
                selected = allowed.intersection(os.listdir(image_dir))

                image_names.extend(
                    os.path.join(image_dir, x) for x in selected)
                mask_names.extend(
                    os.path.join(mask_dir, "masc_" + str(x)) for x in selected)

        self.image_names = sorted(image_names)
        self.mask_names = sorted(mask_names)

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_names)

    @staticmethod
    def _load_array(path, from_archive):
        """Load one array from ``path``, closing ``.npz`` archives afterwards."""
        if not from_archive:
            return np.load(path)
        # np.load keeps an .npz archive open until it is closed; read the
        # array inside a context manager so the file handle is released.
        with np.load(path) as archive:
            return archive['arr_0']

    @staticmethod
    def _build_augmenter():
        """Build the imgaug pipeline used for online augmentation."""
        return iaa.Sequential([
            # Random scaling and rotation.
            iaa.Affine(
                scale=(0.5, 1.2),
                rotate=(-15, 15)
            ),
            iaa.Flipud(0.5),
            iaa.PiecewiseAffine(scale=(0.01, 0.05)),
            iaa.Sometimes(
                0.1,
                iaa.GaussianBlur((0.1, 1.5)),
            ),
            iaa.Sometimes(
                0.1,
                iaa.LinearContrast((0.5, 2.0), per_channel=0.5),
            )
        ])

    def _augment(self, image, mask):
        """Apply the online augmentation jointly to an image and its mask."""
        # NOTE(review): imgaug expects the shape of the underlying image
        # here; it is hard-coded, so presumably all images are 256x256 -
        # confirm.
        segmap = SegmentationMapsOnImage(mask, shape=(256, 256))

        # Built per call, as before, so random-state handling is unchanged.
        seq = self._build_augmenter()

        # Move the first axis last for imgaug, then restore the layout.
        image = image.transpose(1, 2, 0)
        image, segmap = seq(image=image, segmentation_maps=segmap)
        image = image.copy().transpose(2, 0, 1)
        return image, segmap.copy().get_arr()

    def __getitem__(self, idx):
        """Load, optionally augment and transform the sample at ``idx``.

        Returns:
            A dict with ``'image'`` and ``'mask'`` entries, or whatever
            ``transform`` returns for that dict when a transform is set.
        """
        image_path = self.image_names[idx]

        # The image's extension decides how both files are unpacked.
        __, file_extension = os.path.splitext(image_path)
        from_archive = file_extension == '.npz'

        image = self._load_array(image_path, from_archive)
        mask = self._load_array(self.mask_names[idx], from_archive)

        if self.in_channels == 1:
            image = image[1]

        if self.data_aug == 'online':
            image, mask = self._augment(image, mask)

        sample = {'image': image, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)

        return sample