Diff of /model/data_loader.py [000000] .. [39fb2b]

Switch to unified view

a b/model/data_loader.py
1
import random
2
import os
3
4
from PIL import Image
5
import torch
6
from torch.utils.data import Dataset, DataLoader
7
import torchvision.transforms as transforms
8
9
import pandas as pd
10
import re
11
import numpy as np
12
import utils
13
14
# borrowed from http://pytorch.org/tutorials/advanced/neural_style_tutorial.html
15
# and http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
16
# define a training image loader that specifies transforms on images. See documentation for more details.
17
18
def get_tfms_3d(split, params):
    """
    Build the per-sample transform used for 3D arrays.

    Args:
        split: (string) "train" selects the random-crop pipeline; any other
            value selects the deterministic pipeline
        params: object forwarded to normalize(); its n_crop_vox attribute
            sets how much is cropped / unpadded

    Returns:
        tfms: callable taking one array and returning the transformed array
    """
    if split != "train":
        def tfms(x):
            # deterministic counterpart of the training crop: strip half of
            # n_crop_vox via unpad() instead of cropping at a random offset
            return unpad(normalize(x, params), int(params.n_crop_vox / 2))

        return tfms

    def tfms(x):
        cropped = random_crop(normalize(x, params), params.n_crop_vox)
        # batchgenerators transforms expect a batch dim and a channel dim;
        # add both in front and squeeze them off again (no augmentation is
        # currently applied in between)
        with_batch_and_channel = cropped[np.newaxis, np.newaxis, ...]
        return np.squeeze(with_batch_and_channel)

    return tfms
39
40
def get_tfms(split = "train", size = 51):
    """
    Build the torchvision transform pipeline for 2D images.

    Args:
        split: (string) "train" adds random cropping and random flips; any
            other value uses a center crop only
        size: crop size handed to RandomCrop / CenterCrop

    Returns:
        tfms: (transforms.Compose) pipeline ending in ToTensor()
    """
    if split == "train":
        steps = [
            transforms.RandomCrop(size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ]
    else:
        steps = [transforms.CenterCrop(size)]

    # both pipelines finish by turning the image into a torch tensor
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
68
69
def normalize_2d(x):
    """
    Divide x by 255.
    """
    divisor = 255
    return x / divisor
71
72
def normalize(x, params=None):
    """
    Linearly rescale x with fixed bounds and clamp both tails.

    Args:
        x: (np.ndarray) values to rescale; the input itself is not modified,
            a new array is returned
        params: optional object with attributes hu_min, hu_max and pix_mean
            (presumably Hounsfield-unit bounds, judging by the names). When
            None, the bounds -1000. / 600. and a pix_mean of .25 are used.

    Returns:
        (np.ndarray) (x - min) / (max - min), where entries above
        1 - pix_mean are set to 1 and entries below -pix_mean are set to 0
    """
    if params is None:
        min_bound, max_bound, pixel_mean = -1000., 600.0, .25
    else:
        min_bound = params.hu_min
        max_bound = params.hu_max
        pixel_mean = params.pix_mean

    x = (x - min_bound) / (max_bound - min_bound)
    # NOTE(review): the clamp thresholds are shifted by pixel_mean instead of
    # sitting at 1 and 0 -- kept as found, confirm this is intended.
    x[x > (1 - pixel_mean)] = 1.
    # The lower clamp used to be written with `-` instead of `=`, which made
    # it a no-op and left the lower tail unclamped.
    x[x < (0 - pixel_mean)] = 0.
    return x
81
82
def random_crop(x, num_vox=3):
    """
    Cut a randomly positioned block out of x that is num_vox entries shorter
    along every axis.

    For each axis an independent start offset is drawn uniformly from
    0 .. num_vox - 1 and the following (axis_length - num_vox) entries are
    kept. An axis that is not longer than num_vox yields an empty result.

    Args:
        x: (np.ndarray) array to crop; it is not modified
        num_vox: (int) number of entries removed per axis. 0 returns x
            unchanged, so that cropping can be switched off the same way
            unpad() allows with n=0.

    Returns:
        (np.ndarray) the cropped copy (or x itself when num_vox is 0)
    """
    if num_vox == 0:
        # np.random.choice refuses to sample from an empty range
        return x

    starts = np.random.choice(range(num_vox), replace=True, size=(x.ndim,))
    # one slice per axis; a single indexing operation replaces the per-axis
    # take() calls, each of which copied the whole intermediate array
    window = tuple(
        slice(start, start + max(dim - num_vox, 0))
        for start, dim in zip(starts, x.shape)
    )
    # copy so the result does not keep the full input array alive as a view
    return x[window].copy()
88
89
def unpad(x, n=2):
    """
    Strip n entries from both ends of the first three axes of x.

    Args:
        x: (np.ndarray) array to trim
        n: (int) entries removed per side; values that are not positive
            leave x untouched

    Returns:
        (np.ndarray) x[n:-n, n:-n, n:-n], or x itself when n is not positive
    """
    assert type(x) is np.ndarray
    if n <= 0:
        return x
    inner = slice(n, -n)
    return x[inner, inner, inner]
97
98
class LIDCDataset(Dataset):
    """
    PyTorch Dataset that serves one image plus the tabular variables of one
    dataframe row per index. Defines __len__ and __getitem__.
    """
    def __init__(self, data_dir, transform, df, setting, params):
        """
        Store the dataframe describing the samples and derive the size
        variable "x" used downstream.

        Args:
            data_dir: (string) directory the entries of the "name" column are
                joined onto
            transform: callable applied to every loaded image
            df: (pandas.DataFrame) one row per sample. Must contain the
                columns "name", "x" and setting.outcome[0]; "t" is read when
                setting.covar_mode is truthy. The frame is copied, the
                caller's object is left untouched.
            setting: object providing mode3d, covar_mode, fase and outcome
            params: object providing size_offset and size_measurement
                ('area', 'diameter' or 'volume')

        Raises:
            AssertionError: if df has no "name" column
            ValueError: if params.size_measurement is not a known measurement
        """
        # Work on a private copy: the columns written below used to be added
        # to the caller's frame, so building a second dataset from the same
        # frame transformed "x" twice and overwrote "x_true" with an already
        # transformed value.
        df = df.copy()

        self.setting = setting
        self.params = params
        self.data_dir = data_dir
        self.transform = transform
        self.df = df
        self.mode3d = setting.mode3d
        self.covar_mode = setting.covar_mode
        self.fase = setting.fase

        assert ("name" in df.columns)

        # positional column indices, used with .iloc in __getitem__
        self.name_col  = df.columns.get_loc("name")
        self.label_col = df.columns.get_loc(setting.outcome[0])
        self.data_cols = list(set(range(len(self.df.columns))) -
                                set([self.name_col, self.label_col]))

        # covariate data that is not name or label
        if self.covar_mode:
            self.data = self.df.loc[:,"t"].values

        # keep the untransformed size around
        df['x_true'] = df.x

        # calculate a transformation of x to assess robustness of method to different measurements
        df['x'] = df.x + params.size_offset
        if params.size_measurement == 'area':
            pass
        elif params.size_measurement == 'diameter':
            df['x'] = df.x.values ** (1/2)
        elif params.size_measurement == 'volume':
            df['x'] = df.x.values ** (3/2)
        else:
            raise ValueError(f'dont know how to measure size in {params.size_measurement}, pick area, diameter or volume')
        # renormalize x to make sure that whatever measurement is used, the MSE is comparable
        df['x'] = (df.x - df.x.mean()) / df.x.std()

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
        Fetch index idx image and labels from dataset. Perform transforms on image.

        Args:
            idx: (int) index in [0, 1, ..., size_of_dataset-1]

        Returns:
            sample: (dict) with
                "image": the transformed image (in 3D mode a tensor with one
                    extra leading dimension),
                "label": 0-d float32 tensor holding the outcome column,
                one float32 entry for each of "x", "y", "z", "t", "x_true"
                    that exists as a column,
                and, when setting.fase == "feature", the outcome again under
                    its own column name.
        """
        img_path = os.path.join(self.data_dir, self.df.iloc[idx, self.name_col])

        if self.mode3d:
            # 3D samples are stored as numpy arrays on disk
            image = np.load(img_path)
            image = image.astype(np.float32)
            image = self.transform(image)
            image = torch.from_numpy(image).unsqueeze(0)
        else:
            image = Image.open(img_path).convert("L") # use rgb for resnet compatibility; L for grayscale
            image = self.transform(image)

        label = torch.from_numpy(np.array(self.df.iloc[idx, self.label_col], dtype = np.float32))

        sample = {"image": image, 'label': label}

        for variable in ["x", "y", "z", "t", 'x_true']:
            if variable in self.df.columns:
                sample[variable] = self.df[variable].values[idx].astype(np.float32)

        if self.setting.fase == "feature":
            sample[self.setting.outcome[0]] = self.df[self.setting.outcome[0]].values[idx].astype(np.float32)

        return sample
190
191
def fetch_dataloader(args, params, setting, types=("train",), df=None):
    """
    Fetches the DataLoader object for each type in types.

    The data directory is "data" (3D mode) or "slices" (2D mode) when
    setting.gen_model is empty, otherwise <setting.home>/data.

    Args:
        args: unused; kept for interface compatibility
        params: (Params) hyperparameters; batch_size, num_workers and cuda
            configure the loaders, n_crop_vox / size configure the transforms
        setting: object providing gen_model, mode3d, home and outcome
        types: (sequence) has one or more of 'train', 'val', 'test' depending
            on which data is required
        df: pandas dataframe containing at least name and the outcome column,
            and optionally split. When None, labels.csv is read from the data
            directory. Without a "split" column every row is assigned to
            types[0].

    Returns:
        data: (dict) contains the DataLoader object for each type in types
            that occurs in the split column

    Raises:
        AssertionError: if the outcome column setting.outcome[0] is missing
    """
    if setting.gen_model == "":
        data_dir = "data" if setting.mode3d else "slices"
    else:
        data_dir = os.path.join(setting.home, "data")

    if df is None:
        df = pd.read_csv(os.path.join(data_dir, "labels.csv"))

    # make sure the dataframe has no index; drop=True discards the old index
    # instead of inserting it as a column first, which fails when the index
    # name collides with an existing column
    df = df.reset_index(drop=True)

    outcome = setting.outcome[0]
    if outcome not in df.columns:
        # explicit check instead of assert + bare except: it survives
        # `python -O` and does not intercept unrelated exceptions
        print(f"outcome {setting.outcome[0]} not in df.columns:")
        print("\n".join(df.columns))
        raise AssertionError(f"outcome {outcome} not in df.columns")

    if "split" not in df.columns:
        df["split"] = types[0]

    dataloaders = {}
    for split, df_split in df.groupby("split"):
        if split not in types:
            continue

        df_split = df_split.drop("split", axis = 1)
        if setting.mode3d:
            tfms = get_tfms_3d(split, params)
        else:
            tfms = get_tfms(split, params.size)

        # only the training split is shuffled; the loaders are otherwise
        # configured identically
        dataloaders[split] = DataLoader(
            LIDCDataset(data_dir, tfms, df_split, setting, params),
            shuffle=(split == 'train'),
            num_workers=params.num_workers,
            batch_size=params.batch_size,
            pin_memory=params.cuda)

    return dataloaders