|
a |
|
b/Dataset/datasetloader.py |
|
|
1 |
import numpy as np |
|
|
2 |
import cv2 |
|
|
3 |
from torch.utils.data import Dataset |
|
|
4 |
import torch |
|
|
5 |
import random |
|
|
6 |
IGNORED = ['.DS_Store']  # presumably filenames to skip when scanning dataset dirs — unused in this chunk, verify against callers
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class MRIDataset(Dataset):
    """In-memory dataset of grayscale MRI slices and binary segmentation masks.

    Every image in ``imgpath`` and every mask in ``labelpath`` is read with
    OpenCV, preprocessed, and converted to a ``(1, H, W)`` tensor at
    construction time. ``__getitem__`` applies a random horizontal flip
    (p = 0.5) to the image/mask pair as data augmentation.

    Args:
        imgpath: iterable of file paths to the grayscale input images.
        labelpath: iterable of file paths to the corresponding mask images
            (read as BGR; any nonzero pixel counts as foreground).
        preprocessors: optional list of objects exposing ``preprocess(img)``,
            applied in order to each image and each mask.
        verbose: print a progress line every ``verbose`` items; <= 0 disables.
    """

    def __init__(self, imgpath, labelpath, preprocessors=None, verbose=-1):
        super(MRIDataset, self).__init__()
        # Normalize to a list so the loops below can always iterate it.
        self.preprocessors = preprocessors if preprocessors is not None else []
        self.imgpath = imgpath
        self.labelpath = labelpath

        self.images = []
        self.masks = []

        for i, path in enumerate(self.imgpath):
            # Flag 0 -> load as single-channel grayscale.
            image = cv2.imread(path, 0)

            for p in self.preprocessors:
                image = p.preprocess(image)

            # (H, W) -> (1, H, W): add the channel dimension.
            image = torch.from_numpy(image).unsqueeze(0)
            self.images.append(image)

            if verbose > 0 and i > 0 and (i + 1) % verbose == 0:
                # BUG FIX: was len(path) — the length of the path *string* —
                # which misreported the total; use the dataset size instead.
                print("[INFO] processed {}/{}".format(i + 1, len(self.imgpath)))

        for i, path in enumerate(self.labelpath):
            label = cv2.imread(path)

            for p in self.preprocessors:
                label = p.preprocess(label)

            # Collapse the color channels, then binarize: any nonzero
            # pixel becomes foreground (sum is integral, so > 0.5 == > 0).
            label = np.sum(label, axis=2)
            label = label > 0.5
            label = torch.from_numpy(label).unsqueeze(0)
            self.masks.append(label)

            if verbose > 0 and i > 0 and (i + 1) % verbose == 0:
                # BUG FIX: same len(path) defect as the image loop above.
                print("[INFO] processed {}/{}".format(i + 1, len(self.labelpath)))

    def __len__(self):
        """Return the number of image/mask pairs loaded."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx`` as float tensors.

        With probability 0.5, image and mask are horizontally flipped
        together (data augmentation).
        """
        image = self.images[idx]
        mask = self.masks[idx]

        if random.random() > 0.5:
            # BUG FIX: the original flipped dim 0, the size-1 channel axis
            # added by unsqueeze(0) — a guaranteed no-op. Flip the width
            # axis so the augmentation actually mirrors the slice.
            image = torch.flip(image, [-1])
            mask = torch.flip(mask, [-1])

        return image.float(), mask.float()
|
|
70 |
|