Diff of /Dataset/datasetloader.py [000000] .. [6d4adb]

Switch to unified view

a b/Dataset/datasetloader.py
1
import numpy as np
2
import cv2
3
from torch.utils.data import Dataset
4
import torch
5
import random
6
IGNORED = ['.DS_Store']
7
8
9
class MRIDataset(Dataset):
10
    def __init__(self, imgpath, labelpath, preprocessors=None, verbose=-1):
11
        super(MRIDataset, self).__init__()
12
        # store the image preprocessor
13
        self.preprocessors = preprocessors
14
        self.imgpath = imgpath
15
        self.labelpath = labelpath
16
17
        # if the preprocessors are None, initialize them as an
18
        # empty list
19
        if self.preprocessors is None:
20
            self.preprocessors = []
21
22
        self.images = []
23
        self.masks = []
24
25
        for (i, path) in enumerate(self.imgpath):
26
            image = cv2.imread(path,0)
27
28
            if self.preprocessors is not None:
29
                for p in self.preprocessors:
30
                    image = p.preprocess(image)
31
32
            image = torch.from_numpy(image)
33
            image = image.unsqueeze(0)
34
35
            self.images.append(image)
36
37
            if verbose > 0 and i > 0 and (i + 1) % verbose == 0:
38
                print("[INFO] processed {}/{}".format(i + 1, len(path)))
39
40
        for (i, path) in enumerate(self.labelpath):
41
            label = cv2.imread(path)
42
43
            if self.preprocessors is not None:
44
                for p in self.preprocessors:
45
                    label = p.preprocess(label)
46
            label = np.sum(label, axis=2)
47
            label = label > 0.5
48
            label = torch.from_numpy(label)
49
            label = label.unsqueeze(0)
50
51
            self.masks.append(label)
52
53
            if verbose > 0 and i > 0 and (i + 1) % verbose == 0:
54
                print("[INFO] processed {}/{}".format(i + 1, len(path)))
55
56
    def __len__(self):
57
        return len(self.images)
58
59
    def __getitem__(self, idx):
60
61
        image = self.images[idx]
62
        mask = self.masks[idx]
63
64
        # Flip image for data augmentation
65
        if random.random() > 0.5:
66
            image = torch.flip(image, [0])
67
            mask = torch.flip(mask, [0])
68
69
        return image.float(), mask.float()
70