Diff of /src/data.py [000000] .. [f45789]

Switch to side-by-side view

--- a
+++ b/src/data.py
@@ -0,0 +1,145 @@
+import torch
+from torch.utils.data import Dataset
+import json
+import os
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+import yaml
+
+
+class IHDataset(Dataset):
+    """Labelled Intracranial Hemorrhage (IH) slice dataset.
+
+    Files read below ``root_dir``:
+
+    * ``annots/<stage>.csv`` -- comma-separated annotations with a header
+      row. The ``ID`` column supplies the slice IDs and the ``IH`` column
+      the label of each slice (converted with ``int``; per the original
+      note in this file, 1 = IH and 0 = no IH).
+    * ``<stage>/<ID>.png`` -- one image per annotated slice.
+    * ``annots/patient_stats_test.yaml`` -- parsed with ``yaml.safe_load``
+      and stored unchanged in ``self.test_patients`` for every stage.
+
+    Args:
+        root_dir: Root folder of the dataset.
+        stage: Split name; selects both the CSV file and the image folder.
+        transform: Optional callable applied to the opened PIL image.
+    """
+    def __init__(self, root_dir, stage='train', transform=None):
+        self.root_dir = root_dir
+        self.stage = stage
+        self.transform = transform
+        self.data_dir = os.path.join(self.root_dir, self.stage)
+
+        # Column name -> list of (string) values, one entry per CSV row.
+        self.annots = self._read_annots(
+            os.path.join(root_dir, f'annots/{stage}.csv'))
+
+        # Two-way mapping between integer index and slice ID.
+        self.idx_to_slice_id = dict(enumerate(self.annots['ID']))
+        self.slice_id_to_idx = {v: k for k, v in self.idx_to_slice_id.items()}
+
+        test_patients_path = os.path.join(root_dir,
+                                          'annots/patient_stats_test.yaml')
+        # Context manager so the handle is closed as soon as parsing is done.
+        with open(test_patients_path, 'r') as yaml_file:
+            self.test_patients = yaml.safe_load(yaml_file)
+
+    @staticmethod
+    def _read_annots(csv_path):
+        """Parse the annotation CSV into ``{column name: [values, ...]}``.
+
+        Blank lines are skipped: a blank line would otherwise add an empty
+        value to the first column only and leave the columns misaligned.
+
+        Raises:
+            ValueError: If the file contains no header row.
+        """
+        with open(csv_path, 'r') as csv_file:
+            rows = [line.strip().split(',')
+                    for line in csv_file if line.strip()]
+        if not rows:
+            raise ValueError(f'Annotation file {csv_path} is empty.')
+
+        header, rows = rows[0], rows[1:]
+        annots = {k: [] for k in header}
+        for row in rows:
+            for k, v in zip(header, row):
+                annots[k].append(v)
+        return annots
+
+    def __len__(self):
+        """Return the number of annotated slices."""
+        return len(self.annots['ID'])
+
+    def __getitem__(self, idx):
+        """Return ``(sample, target)`` for the slice at position ``idx``.
+
+        ``sample`` is the image opened with PIL, passed through
+        ``self.transform`` when one was given; ``target`` is the ``IH``
+        column value of the same row converted to ``int``.
+        """
+        img_name = f'{self.idx_to_slice_id[idx]}.png'
+        img_path = os.path.join(self.data_dir, img_name)
+        sample = Image.open(img_path)
+
+        if self.transform:
+            sample = self.transform(sample)
+
+        target = int(self.annots['IH'][idx])
+        return sample, target
+
+    def getSlice(self, slice_id):
+        """Return ``(sample, target)`` for the slice with ID ``slice_id``.
+
+        Raises:
+            KeyError: If ``slice_id`` is not present in the annotations.
+        """
+        idx = self.slice_id_to_idx[slice_id]
+        return self.__getitem__(idx)
+
+
+class IHTestDataset(Dataset):
+    """Unlabelled image-folder dataset.
+
+    Every regular file located directly inside ``root_dir`` is treated as
+    an image. Items are ``(sample, img_path)`` pairs, so results can be
+    traced back to the file they came from.
+
+    Args:
+        root_dir: Folder containing the images.
+        stage: Not used by this class.
+        transform: Optional callable applied to the opened PIL image.
+    """
+    def __init__(self, root_dir, stage='', transform=None):
+        import glob
+        self.root_dir = root_dir
+        self.transform = transform
+        # Joining with '' appends a trailing separator only when missing and
+        # does not fail on an empty ``root_dir``.
+        self.data_dir = os.path.join(root_dir, '')
+        # glob returns entries in arbitrary order and also matches
+        # sub-directories; sort for a reproducible index -> path mapping and
+        # keep regular files only, since directories cannot be opened as
+        # images.
+        self.data = sorted(path for path in glob.glob(self.data_dir + '*')
+                           if os.path.isfile(path))
+        self.idx_to_img_path = dict(enumerate(self.data))
+
+    def __len__(self):
+        """Return the number of image files found."""
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        """Return ``(sample, img_path)`` for the file at position ``idx``."""
+        img_path = self.idx_to_img_path[idx]
+        sample = Image.open(img_path)
+        if self.transform:
+            sample = self.transform(sample)
+        return sample, img_path
+
+
+def get_datasets(conf):
+    """Build the datasets selected by ``conf['data']['name']``.
+
+    Args:
+        conf: Mapping providing ``conf['data']['name']``,
+            ``conf['data']['path']`` and the ``train_transform``,
+            ``valid_transform`` and ``test_transform`` entries.
+
+    Returns:
+        A 4-tuple ``(train, valid, test, patients)``. For ``'IHDataset'``
+        these are the ``train``, ``valid``, ``test`` and
+        ``test_no_balanced`` stages (the last two share the test
+        transform). For ``'IHTestDataset'`` only the third element is set;
+        the others are ``None``.
+
+    Raises:
+        SystemExit: If the dataset name is not supported.
+    """
+    import sys
+
+    dataset = conf['data']['name']
+    root_dir = conf['data']['path']
+    train_transform = conf['train_transform']
+    valid_transform = conf['valid_transform']
+    test_transform = conf['test_transform']
+
+    if dataset == 'IHTestDataset':
+        return None, None, IHTestDataset(root_dir=root_dir,
+                                         transform=test_transform), None
+
+    if dataset != 'IHDataset':
+        # f-string so the offending name is actually shown; exit with a
+        # non-zero status because this is an error condition.
+        print(f'Dataset {dataset} not supported.')
+        sys.exit(1)
+
+    train_dataset = IHDataset(root_dir=root_dir,
+                              stage='train',
+                              transform=train_transform)
+    valid_dataset = IHDataset(root_dir=root_dir,
+                              stage='valid',
+                              transform=valid_transform)
+    test_dataset = IHDataset(root_dir=root_dir,
+                             stage='test',
+                             transform=test_transform)
+    patients_dataset = IHDataset(root_dir=root_dir,
+                                 stage='test_no_balanced',
+                                 transform=test_transform)
+    return train_dataset, valid_dataset, test_dataset, patients_dataset
+
+
+def get_dataloaders(conf):
+    """Create the train, validation and test ``DataLoader`` objects.
+
+    Args:
+        conf: Mapping providing ``train_dataset``, ``valid_dataset`` and
+            ``test_dataset`` plus ``conf['data']['batch_size']`` and
+            ``conf['data']['num_workers']``.
+
+    Returns:
+        ``(train_loader, valid_loader, test_loader)``. Only the training
+        loader shuffles; all three use ``pin_memory=True``.
+    """
+    from torch.utils.data import DataLoader
+
+    batch_size = conf['data']['batch_size']
+    num_workers = conf['data']['num_workers']
+
+    def make_loader(dataset, shuffle):
+        # The three loaders differ only in dataset and shuffling.
+        return DataLoader(dataset=dataset,
+                          batch_size=batch_size,
+                          num_workers=num_workers,
+                          shuffle=shuffle,
+                          pin_memory=True)
+
+    train_loader = make_loader(conf['train_dataset'], shuffle=True)
+    valid_loader = make_loader(conf['valid_dataset'], shuffle=False)
+    # Previously this loader wrapped the validation dataset by mistake, so
+    # the test dataset was never iterated.
+    test_loader = make_loader(conf['test_dataset'], shuffle=False)
+
+    return train_loader, valid_loader, test_loader
+
+
+if __name__ == '__main__':
+    # Manual smoke test: open the first file of one patient folder and
+    # re-save it as a JPEG. The string literal below holds an older,
+    # disabled check for IHDataset and has no effect at runtime.
+    '''
+    print(os.listdir('./data/windowed'))
+    ds = IHDataset(root_dir='./data/windowed/', stage='valid')
+    sample, target = ds[2]
+    arr = np.transpose(255 * (0.5 * sample + 0.5), (1,2,0))
+    im = Image.fromarray(np.uint8(arr))
+    im.save('example.jpg')
+    print(target)
+    print(len(ds))
+    '''
+    patient_ds = IHTestDataset(root_dir='../patients_windowed/001_1/')
+    first_image, first_path = patient_ds[0]
+    print('Img path:', first_path)
+    # No transform was passed, so the sample is the image as opened by PIL.
+    pixels = np.uint8(first_image)
+    Image.fromarray(pixels).save('example.jpg')