Diff of /src/data.py [000000] .. [f45789]

Switch to unified view

a b/src/data.py
1
import torch
2
from torch.utils.data import Dataset
3
import json
4
import os
5
from PIL import Image
6
import numpy as np
7
from torchvision import transforms
8
import yaml
9
10
11
class IHDataset(Dataset):
12
    """Intracranial Hemorrhage dataset."""
13
    def __init__(self, root_dir, stage='train', transform=None):
14
        self.root_dir = root_dir
15
        self.stage = stage
16
        self.transform = transform
17
        self.data_dir = os.path.join(self.root_dir, self.stage)
18
19
        with open(os.path.join(root_dir, f'annots/{stage}.csv'), 'r') as csv:
20
            lines = [line.strip().split(',') for line in csv.readlines()]
21
            header, lines = lines[0], lines[1:]
22
            # k: dicom ID, v: IH/noIH (1/0)
23
            self.annots = {k:[] for k in header}
24
            for line in lines:
25
                for k,v in zip(header,line):
26
                    self.annots[k].append(v)
27
28
        self.idx_to_slice_id = dict(enumerate(self.annots['ID']))
29
        self.slice_id_to_idx = {v:k for k,v in self.idx_to_slice_id.items()}
30
31
        test_patients_path = os.path.join(root_dir,
32
                                          f'annots/patient_stats_test.yaml')
33
        self.test_patients = yaml.safe_load(open(test_patients_path, 'r'))
34
35
    def __len__(self):
36
        return len(self.annots['ID'])
37
38
    def __getitem__(self, idx):
39
        img_name = f'{self.idx_to_slice_id[idx]}.png'
40
        img_path = os.path.join(self.data_dir, img_name)
41
        sample = Image.open(img_path)
42
43
        if self.transform:
44
            sample = self.transform(sample)
45
46
        target = int(self.annots['IH'][idx])
47
        return sample, target
48
49
    def getSlice(self, slice_id):
50
        idx = self.slice_id_to_idx[slice_id]
51
        return self.__getitem__(idx)
52
53
54
class IHTestDataset(Dataset):
55
    """Intracranial Hemorrhage dataset."""
56
    def __init__(self, root_dir, stage='', transform=None):
57
        self.root_dir = root_dir
58
        self.transform = transform
59
        self.data_dir = root_dir if root_dir[-1] == '/' else root_dir + '/'
60
        import glob
61
        self.data = glob.glob(self.data_dir + '*')
62
        self.idx_to_img_path = dict(enumerate(self.data))
63
64
    def __len__(self):
65
        return len(self.data)
66
67
    def __getitem__(self, idx):
68
        img_path = self.idx_to_img_path[idx]
69
        sample = Image.open(img_path)
70
        if self.transform:
71
            sample = self.transform(sample)
72
        return sample, img_path
73
74
75
def get_datasets(conf):
76
    dataset = conf['data']['name']
77
    root_dir = conf['data']['path']
78
    train_transform = conf['train_transform']
79
    valid_transform = conf['valid_transform']
80
    test_transform = conf['test_transform']
81
    if dataset == 'IHDataset':
82
        train_dataset = IHDataset(root_dir=root_dir,
83
                                  stage='train',
84
                                  transform=train_transform)
85
        valid_dataset = IHDataset(root_dir=root_dir,
86
                                  stage='valid',
87
                                  transform=valid_transform)
88
        test_dataset = IHDataset(root_dir=root_dir,
89
                                  stage='test',
90
                                  transform=test_transform)
91
        patients_dataset = IHDataset(root_dir=root_dir,
92
                                     stage='test_no_balanced',
93
                                     transform=test_transform)
94
    elif dataset == 'IHTestDataset':
95
        return None, None, IHTestDataset(root_dir=root_dir,
96
                                         transform=test_transform), None
97
    else:
98
        print('Dataset {dataset} not supported.')
99
        exit()
100
    return train_dataset, valid_dataset, test_dataset, patients_dataset
101
102
103
def get_dataloaders(conf):
104
    from torch.utils.data import DataLoader
105
    train_dataset = conf['train_dataset']
106
    valid_dataset = conf['valid_dataset']
107
    test_dataset = conf['test_dataset']
108
    batch_size = conf['data']['batch_size']
109
    num_workers = conf['data']['num_workers']
110
    train_loader = DataLoader(dataset=train_dataset,
111
                              batch_size=batch_size,
112
                              num_workers=num_workers,
113
                              shuffle=True,
114
                              pin_memory=True)
115
    valid_loader = DataLoader(dataset=valid_dataset,
116
                              batch_size= batch_size,
117
                              num_workers=num_workers,
118
                              shuffle=False,
119
                              pin_memory=True)
120
    test_loader = DataLoader(dataset=valid_dataset,
121
                             batch_size= batch_size,
122
                             num_workers=num_workers,
123
                             shuffle=False,
124
                             pin_memory=True)
125
126
    return train_loader, valid_loader, test_loader
127
128
129
if __name__ == '__main__':
130
    '''
131
    print(os.listdir('./data/windowed'))
132
    ds = IHDataset(root_dir='./data/windowed/', stage='valid')
133
    sample, target = ds[2]
134
    arr = np.transpose(255 * (0.5 * sample + 0.5), (1,2,0))
135
    im = Image.fromarray(np.uint8(arr))
136
    im.save('example.jpg')
137
    print(target)
138
    print(len(ds))
139
    '''
140
    ds = IHTestDataset(root_dir='../patients_windowed/001_1/')
141
    sample, img_path = ds[0]
142
    print('Img path:', img_path)
143
    arr = sample
144
    im = Image.fromarray(np.uint8(arr))
145
    im.save('example.jpg')