|
a |
|
b/utils.py |
|
|
1 |
import os |
|
|
2 |
import json |
|
|
3 |
import torch |
|
|
4 |
import glob |
|
|
5 |
|
|
|
6 |
from torch.utils.data import Dataset, DataLoader |
|
|
7 |
from torchvision import transforms |
|
|
8 |
|
|
|
9 |
from imgaug import augmenters as iaa |
|
|
10 |
from imgaug.augmentables.segmaps import SegmentationMapsOnImage |
|
|
11 |
|
|
|
12 |
import numpy as np |
|
|
13 |
|
|
|
14 |
def create_nested_dir(log_path):
    """Ensure the experiment directory and its 'checkpoint' subdirectory exist.

    The previous version only created 'checkpoint' when *log_path* itself was
    missing, so a pre-existing experiment directory without a checkpoint
    subfolder was never repaired; it was also racy (check-then-create).
    ``exist_ok=True`` makes both creations idempotent and race-free.

    Args:
        log_path: Path of the experiment directory to create.
    """
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(os.path.join(log_path, 'checkpoint'), exist_ok=True)
|
|
19 |
|
|
|
20 |
def load_dataset_dist():
    """Load the train/valid/test case split.

    Reads ``configuration/cases_division.json`` relative to the current
    working directory and returns the parsed JSON object.
    """
    split_path = os.path.join('configuration', 'cases_division.json')
    with open(split_path, 'r') as split_file:
        return json.load(split_file)
|
|
26 |
|
|
|
27 |
def get_data_loaders(data_aug, cases, dataset_dir, batch_size):
    """Build the train and validation DataLoaders.

    Args:
        data_aug: Augmentation mode forwarded to the training loader; the
            validation loader always uses 'none'.
        cases: Mapping with 'train' and 'valid' case lists.
        dataset_dir: Root directory of the dataset.
        batch_size: Batch size for both loaders.

    Returns:
        Dict with keys 'Train' and 'Valid' mapping to DataLoaders.
    """
    train_loader = get_dataset(dataset_dir, data_aug, cases=cases['train'],
                               balanced_filelist=None, batch_size=batch_size)
    valid_loader = get_dataset(dataset_dir, 'none', cases=cases['valid'],
                               batch_size=batch_size)
    return {'Train': train_loader, 'Valid': valid_loader}
|
|
37 |
|
|
|
38 |
def get_dataset(data_dir, data_aug, cases=None, balanced_filelist=None, imageFolder='Images', maskFolder='Masks', batch_size=4):
    """Create a shuffled DataLoader over the segmentation dataset for *cases*.

    Args:
        data_dir: Root directory of the dataset.
        data_aug: Augmentation mode forwarded to SegNumpyDataset.
        cases: Iterable of case names to include (defaults to none).
        balanced_filelist: Optional whitelist of file names for class balancing.
        imageFolder: Per-case subdirectory holding the image files.
        maskFolder: Per-case subdirectory holding the mask files.
        batch_size: Batch size of the returned loader.

    Returns:
        A torch DataLoader yielding shuffled batches.
    """
    # Bug fix: the default was the mutable literal `cases=[]`, which is shared
    # across calls; use a None sentinel instead (backward-compatible).
    if cases is None:
        cases = []

    # Both entries apply the same transform; 'Test' is kept for interface parity.
    data_transforms = {
        'Train': transforms.Compose([ToTensor()]),
        'Test': transforms.Compose([ToTensor()]),
    }

    image_dataset = SegNumpyDataset(
        data_aug=data_aug, root_dir=data_dir, cases=cases,
        transform=data_transforms['Train'], maskFolder=maskFolder,
        imageFolder=imageFolder, balanced_filelist=balanced_filelist)

    # drop_last: skip a smaller trailing batch so every batch has batch_size items.
    dataloader = DataLoader(
        image_dataset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

    return dataloader
|
|
52 |
|
|
|
53 |
class ToTensor(object):
    """Convert the numpy arrays of a sample dict into float tensors.

    2-D arrays receive a leading singleton channel axis, so outputs are
    always at least (1, H, W).
    """

    def __call__(self, sample, maskresize=None, imageresize=None):
        image = sample['image']
        mask = sample['mask']
        # Promote 2-D arrays to a single-channel (1, H, W) layout.
        if mask.ndim == 2:
            mask = mask[np.newaxis, ...]
        if image.ndim == 2:
            image = image[np.newaxis, ...]
        return {'image': torch.from_numpy(image).float(),
                'mask': torch.from_numpy(mask).float()}
|
|
64 |
|
|
|
65 |
class SegNumpyDataset(Dataset):
    """Segmentation dataset backed by per-case numpy files on disk.

    Expects a layout of ``root_dir/<case>/<imageFolder>/*`` and
    ``root_dir/<case>/<maskFolder>/*``. When a ``balanced_filelist`` is given,
    only the listed image files are used and each mask is looked up as
    ``masc_<image filename>`` in the mask folder.

    Refactor note: the two balanced-filelist branches previously duplicated
    the per-case path-building code; it now lives in ``_balanced_case_files``.
    The offline-augmentation name expansion was also computed identically for
    every case inside the loop; it is hoisted out (behavior unchanged).
    """

    def __init__(self, root_dir, cases, imageFolder, maskFolder, data_aug, cases_number_format=False, transform=None, balanced_filelist=None):
        # NOTE(review): hard-coded; __getitem__ only special-cases the value 1,
        # so 3 effectively means "keep stored channels as-is".
        self.in_channels = 3
        self.root_dir = root_dir
        self.transform = transform
        self.data_aug = data_aug

        if cases_number_format:
            # Cases given as integers, e.g. 0 -> "case_00000".
            cases_names = ["case_{:05d}".format(i) for i in cases]
        else:
            cases_names = cases

        image_names = []
        mask_names = []

        if balanced_filelist is None:
            # Unbalanced: take every file of every case.
            for case in cases_names:
                image_names.extend(glob.glob(os.path.join(
                    self.root_dir, case, imageFolder, '*')))
                mask_names.extend(glob.glob(os.path.join(
                    self.root_dir, case, maskFolder, '*')))
        else:
            # This distinction is necessary because offline data augmentation
            # changes the file names. (translated from the original Portuguese)
            if data_aug != 'offline':
                wanted = set(balanced_filelist)
            else:
                # Expand each entry into its five offline-augmented variants,
                # e.g. "case_00000-0.npz" -> "case_00000-0-aug-0.npz" ... "-aug-4.npz".
                wanted = set(
                    "{}-aug-{}.npz".format(fl.replace(".npz", ""), i)
                    for fl in balanced_filelist for i in range(0, 5))

            for case in cases_names:
                case_images, case_masks = self._balanced_case_files(
                    case, imageFolder, maskFolder, wanted)
                image_names.extend(case_images)
                mask_names.extend(case_masks)

        # Sorting keeps image/mask lists aligned by filename.
        self.image_names = sorted(image_names)
        self.mask_names = sorted(mask_names)

    def _balanced_case_files(self, case, imageFolder, maskFolder, wanted):
        """Return (image paths, mask paths) for one case, restricted to *wanted* filenames."""
        available = set(os.listdir(os.path.join(
            self.root_dir, case, imageFolder)))
        selected = list(wanted.intersection(available))
        images = [os.path.join(self.root_dir, case, imageFolder, x)
                  for x in selected]
        # Masks share the image filename with a "masc_" prefix.
        masks = [os.path.join(self.root_dir, case, maskFolder, "masc_" + str(x))
                 for x in selected]
        return images, masks

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load one (image, mask) pair, applying online augmentation if enabled."""
        image = np.load(self.image_names[idx])
        mask = np.load(self.mask_names[idx])

        _, file_extension = os.path.splitext(self.image_names[idx])

        # .npz archives store the array under numpy's default key 'arr_0'.
        if file_extension == '.npz':
            image = image['arr_0']
            mask = mask['arr_0']

        if self.in_channels == 1:
            # NOTE(review): picks index 1 along the first axis — presumably one
            # channel of a CHW image; confirm this is the intended channel.
            image = image[1]

        if self.data_aug == 'online':
            # assumes masks are 256x256 — TODO confirm against the dataset
            segmap = SegmentationMapsOnImage(mask, shape=(256, 256))

            seq = iaa.Sequential([
                iaa.Affine(
                    scale=(0.5, 1.2),
                    rotate=(-15, 15)
                ),  # rotate the image
                iaa.Flipud(0.5),
                iaa.PiecewiseAffine(scale=(0.01, 0.05)),
                iaa.Sometimes(
                    0.1,
                    iaa.GaussianBlur((0.1, 1.5)),
                ),
                iaa.Sometimes(
                    0.1,
                    iaa.LinearContrast((0.5, 2.0), per_channel=0.5),
                )
            ])

            # imgaug expects HWC; stored arrays are CHW.
            image = image.transpose(1, 2, 0)
            # Apply augmentations for image and mask
            image, mask = seq(image=image, segmentation_maps=segmap)
            image = image.copy()
            mask = mask.copy()
            image = image.transpose(2, 0, 1)
            mask = mask.get_arr()

        sample = {'image': image, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)

        return sample