--- a +++ b/src/data.py @@ -0,0 +1,145 @@ +import torch +from torch.utils.data import Dataset +import json +import os +from PIL import Image +import numpy as np +from torchvision import transforms +import yaml + + +class IHDataset(Dataset): + """Intracranial Hemorrhage dataset.""" + def __init__(self, root_dir, stage='train', transform=None): + self.root_dir = root_dir + self.stage = stage + self.transform = transform + self.data_dir = os.path.join(self.root_dir, self.stage) + + with open(os.path.join(root_dir, f'annots/{stage}.csv'), 'r') as csv: + lines = [line.strip().split(',') for line in csv.readlines()] + header, lines = lines[0], lines[1:] + # k: dicom ID, v: IH/noIH (1/0) + self.annots = {k:[] for k in header} + for line in lines: + for k,v in zip(header,line): + self.annots[k].append(v) + + self.idx_to_slice_id = dict(enumerate(self.annots['ID'])) + self.slice_id_to_idx = {v:k for k,v in self.idx_to_slice_id.items()} + + test_patients_path = os.path.join(root_dir, + f'annots/patient_stats_test.yaml') + self.test_patients = yaml.safe_load(open(test_patients_path, 'r')) + + def __len__(self): + return len(self.annots['ID']) + + def __getitem__(self, idx): + img_name = f'{self.idx_to_slice_id[idx]}.png' + img_path = os.path.join(self.data_dir, img_name) + sample = Image.open(img_path) + + if self.transform: + sample = self.transform(sample) + + target = int(self.annots['IH'][idx]) + return sample, target + + def getSlice(self, slice_id): + idx = self.slice_id_to_idx[slice_id] + return self.__getitem__(idx) + + +class IHTestDataset(Dataset): + """Intracranial Hemorrhage dataset.""" + def __init__(self, root_dir, stage='', transform=None): + self.root_dir = root_dir + self.transform = transform + self.data_dir = root_dir if root_dir[-1] == '/' else root_dir + '/' + import glob + self.data = glob.glob(self.data_dir + '*') + self.idx_to_img_path = dict(enumerate(self.data)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + img_path = self.idx_to_img_path[idx] + sample = Image.open(img_path) + if self.transform: + sample = self.transform(sample) + return sample, img_path + + +def get_datasets(conf): + dataset = conf['data']['name'] + root_dir = conf['data']['path'] + train_transform = conf['train_transform'] + valid_transform = conf['valid_transform'] + test_transform = conf['test_transform'] + if dataset == 'IHDataset': + train_dataset = IHDataset(root_dir=root_dir, + stage='train', + transform=train_transform) + valid_dataset = IHDataset(root_dir=root_dir, + stage='valid', + transform=valid_transform) + test_dataset = IHDataset(root_dir=root_dir, + stage='test', + transform=test_transform) + patients_dataset = IHDataset(root_dir=root_dir, + stage='test_no_balanced', + transform=test_transform) + elif dataset == 'IHTestDataset': + return None, None, IHTestDataset(root_dir=root_dir, + transform=test_transform), None + else: + print('Dataset {dataset} not supported.') + exit() + return train_dataset, valid_dataset, test_dataset, patients_dataset + + +def get_dataloaders(conf): + from torch.utils.data import DataLoader + train_dataset = conf['train_dataset'] + valid_dataset = conf['valid_dataset'] + test_dataset = conf['test_dataset'] + batch_size = conf['data']['batch_size'] + num_workers = conf['data']['num_workers'] + train_loader = DataLoader(dataset=train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + pin_memory=True) + valid_loader = DataLoader(dataset=valid_dataset, + batch_size= batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True) + test_loader = DataLoader(dataset=valid_dataset, + batch_size= batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True) + + return train_loader, valid_loader, test_loader + + +if __name__ == '__main__': + ''' + print(os.listdir('./data/windowed')) + ds = IHDataset(root_dir='./data/windowed/', stage='valid') + sample, target = ds[2] + arr = np.transpose(255 * (0.5 * sample + 0.5), (1,2,0)) + im = Image.fromarray(np.uint8(arr)) + im.save('example.jpg') + print(target) + print(len(ds)) + ''' + ds = IHTestDataset(root_dir='../patients_windowed/001_1/') + sample, img_path = ds[0] + print('Img path:', img_path) + arr = sample + im = Image.fromarray(np.uint8(arr)) + im.save('example.jpg')