--- a +++ b/dataloader.py @@ -0,0 +1,69 @@ +import os +import pandas as pd +import numpy as np + +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +import torch.utils.data as data +from torchvision import transforms +from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine + + +class MRDataset(data.Dataset): + def __init__(self, root_dir, task, plane, train=True, transform=None, weights=None): + super().__init__() + self.task = task + self.plane = plane + self.root_dir = root_dir + self.train = train + if self.train: + self.folder_path = self.root_dir + 'train/{0}/'.format(plane) + self.records = pd.read_csv( + self.root_dir + 'train-{0}.csv'.format(task), header=None, names=['id', 'label']) + else: + transform = None + self.folder_path = self.root_dir + 'valid/{0}/'.format(plane) + self.records = pd.read_csv( + self.root_dir + 'valid-{0}.csv'.format(task), header=None, names=['id', 'label']) + + self.records['id'] = self.records['id'].map( + lambda i: '0' * (4 - len(str(i))) + str(i)) + self.paths = [self.folder_path + filename + + '.npy' for filename in self.records['id'].tolist()] + self.labels = self.records['label'].tolist() + + self.transform = transform + if weights is None: + pos = np.sum(self.labels) + neg = len(self.labels) - pos + self.weights = torch.FloatTensor([1, neg / pos]) + else: + self.weights = torch.FloatTensor(weights) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + array = np.load(self.paths[index]) + label = self.labels[index] + if label == 1: + label = torch.FloatTensor([[0, 1]]) + elif label == 0: + label = torch.FloatTensor([[1, 0]]) + + if self.transform: + array = self.transform(array) + else: + array = np.stack((array,)*3, axis=1) + array = torch.FloatTensor(array) + + # if label.item() == 1: + # weight = np.array([self.weights[1]]) + # weight = torch.FloatTensor(weight) + # else: + # weight = np.array([self.weights[0]]) + # weight = torch.FloatTensor(weight) + + return array, label, self.weights +