model/data_loader.py
|
import random
import os

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import pandas as pd
import re
import numpy as np
import utils
|
|
|
# borrowed from http://pytorch.org/tutorials/advanced/neural_style_tutorial.html
# and http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# define a training image loader that specifies transforms on images. See documentation for more details.


def get_tfms_3d(split, params):
    if split == "train":
        def tfms(x):
            # x = random_3d_crop(x, params.n_crop_vox)
            x = normalize(x, params)
            x = random_crop(x, params.n_crop_vox)
            # batchgenerators transforms expect a batch dim and a channel dim;
            # add these here and squeeze them off afterwards
            x = np.expand_dims(np.expand_dims(x, 0), 0)
            # x = transforms3d.spatial_transforms.augment_mirroring(x)[0]
            x = np.squeeze(x)
            return x
    else:
        def tfms(x):
            x = normalize(x, params)
            x = unpad(x, int(params.n_crop_vox / 2))
            return x

    return tfms
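
# Note on the 3D transforms: the train pipeline crops n_crop_vox voxels per axis at a
# random offset, while the eval pipeline trims n_crop_vox / 2 voxels from each side, so
# (for even n_crop_vox) train and eval volumes end up with the same shape.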
|
|
|
def get_tfms(split="train", size=51):
    if split == "train":
        tfms = transforms.Compose([
            # transforms.CenterCrop(70),
            transforms.RandomCrop(size),
            # transforms.Resize((size, size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # transforms.RandomRotation(90),
            # transforms.Resize((224, 224)),
            # transforms.RandomResizedCrop(size, scale=(.9, 1)),
            # transforms.RandomRotation(12),
            # transforms.RandomAffine(10, translate=(.1, .1), scale=(.1, .1), shear=.1, resample=False, fillcolor=0),
            transforms.ToTensor()
            # normalize_2d
        ])  # transform the image into a torch tensor
    else:
        tfms = transforms.Compose([
            transforms.CenterCrop(size),
            # transforms.Resize((size, size)),
            transforms.ToTensor()
            # normalize_2d
        ])

    return tfms
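
# Note: transforms.ToTensor already scales 8-bit PIL images to [0, 1], which is
# presumably why normalize_2d is left commented out in the pipelines above.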
|
|
|
def normalize_2d(x):
    return x / 255


def normalize(x, params=None):
    if params is None:
        MIN_BOUND, MAX_BOUND, PIXEL_MEAN = -1000., 600., .25
    else:
        MIN_BOUND, MAX_BOUND, PIXEL_MEAN = params.hu_min, params.hu_max, params.pix_mean
    x = (x - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
    x[x > (1 - PIXEL_MEAN)] = 1.
    x[x < (0 - PIXEL_MEAN)] = 0.
    return x
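
# Example with the default bounds: -1000 HU maps to 0.0 and 600 HU maps to 1.0;
# after scaling, values above 1 - PIXEL_MEAN are clipped to 1.0 and values below
# -PIXEL_MEAN are clipped to 0.0.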
|
|
|
def random_crop(x, num_vox=3):
    # crop num_vox voxels off each dimension, split randomly between the start and the end
    starts = np.random.choice(range(num_vox), replace=True, size=(x.ndim,))
    ends = x.shape - (num_vox - starts)
    for i in range(x.ndim):
        x = x.take(indices=range(starts[i], ends[i]), axis=i)
    return x


def unpad(x, n=2):
    """
    Skim off n entries from each side of the three dimensions.
    """
    assert type(x) is np.ndarray
    if n > 0:
        x = x[n:-n, n:-n, n:-n]
    return x
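
# Shape sketch (illustrative, for cubic inputs): random_crop removes num_vox voxels per
# axis, e.g. (67, 67, 67) -> (64, 64, 64) with num_vox=3, while unpad(x, 2) trims two
# voxels from each side, e.g. (64, 64, 64) -> (60, 60, 60).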
|
|
|
class LIDCDataset(Dataset):
    """
    A standard PyTorch definition of a Dataset, which implements __len__ and __getitem__.
    """
    def __init__(self, data_dir, transform, df, setting, params):
        """
        Store the filenames of the images to use and the transforms to apply to them.

        Args:
            data_dir: (string) directory containing the dataset
            transform: (torchvision.transforms) transformation to apply on each image
            df: (pd.DataFrame) dataframe with at least a "name" column and the outcome column
            setting: settings object (mode3d, covar_mode, fase, outcome)
            params: (Params) hyperparameters
        """
        self.setting = setting
        self.params = params
        self.data_dir = data_dir
        self.transform = transform
        self.df = df
        self.mode3d = setting.mode3d
        self.covar_mode = setting.covar_mode
        self.fase = setting.fase

        # print(df.head())
        # print(df.dtypes)

        assert "name" in df.columns

        self.name_col = df.columns.get_loc("name")
        self.label_col = df.columns.get_loc(setting.outcome[0])
        self.data_cols = list(set(range(len(self.df.columns))) -
                              set([self.name_col, self.label_col]))

        # split off the covariate data, i.e. the columns that are neither name nor label
        if self.covar_mode:
            self.data = self.df.loc[:, "t"].values
            # if len(self.data_cols) > 0:
            #     self.data = self.df.iloc[:, self.data_cols]

        df['x_true'] = df.x

        # calculate a transformation of x to assess the robustness of the method to different measurements
        df['x'] = df.x + params.size_offset
        if params.size_measurement == 'area':
            pass
        elif params.size_measurement == 'diameter':
            df['x'] = df.x.values ** (1 / 2)
        elif params.size_measurement == 'volume':
            df['x'] = df.x.values ** (3 / 2)
        else:
            raise ValueError(f"don't know how to measure size in {params.size_measurement}, pick area, diameter or volume")
        # renormalize x so that the MSE is comparable whatever measurement is used
        df['x'] = (df.x - df.x.mean()) / df.x.std()

    def __len__(self):
        # return the size of the dataset
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
        Fetch the image and label at index idx from the dataset and apply the transforms to the image.

        Args:
            idx: (int) index in [0, 1, ..., size_of_dataset - 1]

        Returns:
            sample: (dict) with the transformed image, its label, and any covariates present in the dataframe
        """
        if self.mode3d:
            image = np.load(os.path.join(self.data_dir, self.df.iloc[idx, self.name_col]))
            image = image.astype(np.float32)
            image = self.transform(image)
            image = torch.from_numpy(image).unsqueeze(0)
        else:
            img_name = os.path.join(self.data_dir,
                                    self.df.iloc[idx, self.name_col])
            image = Image.open(img_name).convert("L")  # use "RGB" for resnet compatibility; "L" for grayscale
            image = self.transform(image)

        label = torch.from_numpy(np.array(self.df.iloc[idx, self.label_col], dtype=np.float32))

        sample = {"image": image, "label": label}

        for variable in ["x", "y", "z", "t", "x_true"]:
            if variable in self.df.columns:
                sample[variable] = self.df[variable].values[idx].astype(np.float32)

        if self.setting.fase == "feature":
            sample[self.setting.outcome[0]] = self.df[self.setting.outcome[0]].values[idx].astype(np.float32)

        return sample
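
# Illustrative sample layout in the 2D case, assuming the dataframe has covariates x and t:
# {"image": FloatTensor of shape [1, size, size], "label": scalar FloatTensor,
#  "x": float32, "t": float32, "x_true": float32}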
|
|
|
def fetch_dataloader(args, params, setting, types=["train"], df=None):
    """
    Fetch a DataLoader object for each type in types from the data directory.

    Args:
        args: command line arguments
        params: (Params) hyperparameters
        setting: settings object that determines the data directory and the data modality
        types: (list) has one or more of 'train', 'val', 'test', depending on which data is required
        df: (pd.DataFrame) dataframe containing at least name, label and split

    Returns:
        dataloaders: (dict) contains a DataLoader object for each type in types
    """
    if setting.gen_model == "":
        if setting.mode3d:
            data_dir = "data"
        else:
            data_dir = "slices"
    else:
        data_dir = os.path.join(setting.home, "data")

    if df is None:
        df = pd.read_csv(os.path.join(data_dir, "labels.csv"))
    dataloaders = {}

    if not setting.mode3d:
        pass
        # print(df.name.tolist()[:5])
        # df["name"] = df.apply(lambda x: os.path.join(x["split"], x["name"]), axis=1)
        # print(df.name.tolist()[:5])

    # make sure the dataframe has no index
    df_cols = df.columns
    df = df.reset_index()
    df = df[df_cols]

    try:
        assert setting.outcome[0] in df.columns
    except AssertionError:
        print(f"outcome {setting.outcome[0]} not in df.columns:")
        print("\n".join(df.columns))
        raise

    if "split" in df.columns:
        splits = [x for x in types if x in df.split.unique().tolist()]
    else:
        df["split"] = types[0]
        splits = types

    df_grp = df.groupby("split")

    for split, df_split in df_grp:
        df_split = df_split.drop("split", axis=1)
        if split in types:
            # path = os.path.join(data_dir, split)
            path = data_dir
            if setting.mode3d:
                tfms = get_tfms_3d(split, params)
            else:
                tfms = get_tfms(split, params.size)

            # use the train transforms for the training data, else the eval transforms without random flips and crops
            if split == 'train':
                dl = DataLoader(LIDCDataset(path, tfms, df_split, setting, params),
                                shuffle=True,
                                num_workers=params.num_workers,
                                batch_size=params.batch_size,
                                pin_memory=params.cuda)
            else:
                dl = DataLoader(LIDCDataset(path, tfms, df_split, setting, params),
                                batch_size=params.batch_size,
                                num_workers=params.num_workers,
                                shuffle=False,
                                pin_memory=params.cuda)

            dataloaders[split] = dl

    return dataloaders
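
# Minimal usage sketch (illustrative only; assumes a params.json with the hyperparameters
# used above and a `setting` object exposing mode3d, covar_mode, fase, outcome, gen_model
# and home):
#
#   params = utils.Params("params.json")
#   dataloaders = fetch_dataloader(args=None, params=params, setting=setting,
#                                  types=["train", "val"])
#   for batch in dataloaders["train"]:
#       images, labels = batch["image"], batch["label"]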