Diff of /dataloader.py [000000] .. [fa8046]

Switch to unified view

a b/dataloader.py
1
import os
2
import pandas as pd
3
import numpy as np
4
5
import torch
6
import torch.nn.functional as F
7
import torchvision.transforms.functional as TF
8
import torch.utils.data as data
9
from torchvision import transforms
10
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine
11
12
13
class MRDataset(data.Dataset):
14
    def __init__(self, root_dir, task, plane, train=True, transform=None, weights=None):
15
        super().__init__()
16
        self.task = task
17
        self.plane = plane
18
        self.root_dir = root_dir
19
        self.train = train
20
        if self.train:
21
            self.folder_path = self.root_dir + 'train/{0}/'.format(plane)
22
            self.records = pd.read_csv(
23
                self.root_dir + 'train-{0}.csv'.format(task), header=None, names=['id', 'label'])
24
        else:
25
            transform = None
26
            self.folder_path = self.root_dir + 'valid/{0}/'.format(plane)
27
            self.records = pd.read_csv(
28
                self.root_dir + 'valid-{0}.csv'.format(task), header=None, names=['id', 'label'])
29
30
        self.records['id'] = self.records['id'].map(
31
            lambda i: '0' * (4 - len(str(i))) + str(i))
32
        self.paths = [self.folder_path + filename +
33
                      '.npy' for filename in self.records['id'].tolist()]
34
        self.labels = self.records['label'].tolist()
35
36
        self.transform = transform
37
        if weights is None:
38
            pos = np.sum(self.labels)
39
            neg = len(self.labels) - pos
40
            self.weights = torch.FloatTensor([1, neg / pos])
41
        else:
42
            self.weights = torch.FloatTensor(weights)
43
44
    def __len__(self):
45
        return len(self.paths)
46
47
    def __getitem__(self, index):
48
        array = np.load(self.paths[index])
49
        label = self.labels[index]
50
        if label == 1:
51
            label = torch.FloatTensor([[0, 1]])
52
        elif label == 0:
53
            label = torch.FloatTensor([[1, 0]])
54
55
        if self.transform:
56
            array = self.transform(array)
57
        else:
58
            array = np.stack((array,)*3, axis=1)
59
            array = torch.FloatTensor(array)
60
61
        # if label.item() == 1:
62
        #     weight = np.array([self.weights[1]])
63
        #     weight = torch.FloatTensor(weight)
64
        # else:
65
        #     weight = np.array([self.weights[0]])
66
        #     weight = torch.FloatTensor(weight)
67
68
        return array, label, self.weights
69