|
a |
|
b/dataset.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pandas as pd |
|
|
3 |
import os |
|
|
4 |
import glob |
|
|
5 |
import torch |
|
|
6 |
import torch.nn as nn |
|
|
7 |
import random |
|
|
8 |
import itertools |
|
|
9 |
from sklearn.metrics import roc_auc_score |
|
|
10 |
import torch.nn.functional as F |
|
|
11 |
from scipy import interpolate |
|
|
12 |
|
|
|
13 |
# Directory containing the binary image files; the image loaders in this file
# glob for each row's `filename` inside it.
data_path_img = '<path-to-binary-image-files>'
# CSV table read by get_datasets_singleview. Code in this file accesses the
# columns `target`, `filename`, `matrix_size_1`, `matrix_size_2` and one
# `split<k>` column per data split.
data_file = '<path-and_name-of-csv-datafile>'
image_size = 310 # this is the size (310x310) of the MIP images used in the paper. Adjust to fit your images.
|
|
16 |
|
|
|
17 |
def get_datasets_singleview(transform=None, norm=None, balance=False, split_index=0):
    """Build the train / train-eval / val / test datasets for one data split.

    The CSV at ``data_file`` is read and partitioned by the labels 'train',
    'val' and 'test' found in the column ``'split' + str(split_index)``.
    Every column whose name matches the regex ``'split'`` is dropped from the
    per-partition frames.

    Args:
        transform: Augmentation handed to the training dataset only; the
            three centred datasets are always built with ``transform=None``.
        norm: Forwarded unchanged to every dataset as its ``norm`` flag.
        balance: Not used by this function.
        split_index: Integer suffix selecting the split column.

    Returns:
        ``(train_dset, trainval_dset, val_dset, test_dset, weight_neg_pos)``.
        ``trainval_dset`` wraps the same rows as ``train_dset`` but uses the
        centred dataset class. ``weight_neg_pos`` is
        ``[1 - frac(target == 0), 1 - frac(target == 1)]`` over the whole CSV.
    """
    table = pd.read_csv(data_file)
    split_col = 'split' + str(split_index)

    # Class weights: one minus the relative frequency of each class.
    n_rows = len(table)
    weight_neg_pos = [1 - (table.target == label).sum() / n_rows for label in (0, 1)]

    split_columns = table.filter(regex='split').columns

    def partition(label):
        # Rows tagged `label` in the chosen split, stripped of all split columns.
        return table[table[split_col] == label].drop(split_columns, axis=1)

    train_frame = partition('train')
    train_dset = dataset_singleview(train_frame, transform=transform, norm=norm)
    trainval_dset = dataset_singleview_center(train_frame, transform=None, norm=norm)
    val_dset = dataset_singleview_center(partition('val'), transform=None, norm=norm)
    test_dset = dataset_singleview_center(partition('test'), transform=None, norm=norm)
    return train_dset, trainval_dset, val_dset, test_dset, weight_neg_pos
|
|
33 |
|
|
|
34 |
def get_bbox(img):
    """Crop a 2D array to the extent of its non-zero content.

    The first and last rows (and columns) that contain any truthy value are
    located and ``img[first_row:last_row, first_col:last_col]`` is returned.
    Slice ends are exclusive, so the last such row and the last such column
    are not part of the returned crop.

    Raises:
        IndexError: If ``img`` contains no truthy element.
    """
    row_hits = np.nonzero(np.any(img, axis=1))[0]
    col_hits = np.nonzero(np.any(img, axis=0))[0]
    top, bottom = row_hits[0], row_hits[-1]
    left, right = col_hits[0], col_hits[-1]
    return img[top:bottom, left:right]
|
|
40 |
|
|
|
41 |
def pad2square_random(image, size):
    """Place a 2D image at a random offset inside a zero-filled square.

    Args:
        image: 2D array with at most ``size`` rows and ``size`` columns.
        size: Side length of the square output.

    Returns:
        A ``(size, size)`` float array of zeros with ``image`` copied in at a
        uniformly sampled row/column offset.

    Raises:
        ValueError: If ``image`` is larger than ``size`` along either axis.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    # Largest admissible offset along each axis.
    maxr = size - n_rows
    maxc = size - n_cols
    if maxr < 0 or maxc < 0:
        raise ValueError(
            'image of shape %s does not fit into a %dx%d square'
            % (tuple(image.shape), size, size)
        )
    out = np.zeros((size, size))
    # randint excludes its upper bound, so pass max + 1: this makes the
    # flush bottom/right placement reachable and keeps the call valid when
    # the image exactly fills a dimension (only offset 0 is possible).
    # The column offset is drawn before the row offset, as before.
    offsetc = np.random.randint(0, maxc + 1)
    offsetr = np.random.randint(0, maxr + 1)
    out[offsetr:offsetr + n_rows, offsetc:offsetc + n_cols] = image
    return out
|
|
51 |
|
|
|
52 |
def pad2square_center(image, size):
    """Embed a 2D image in the middle of a zero-filled ``size`` x ``size`` array.

    The top-left corner of the image lands at
    ``(int((size - rows) / 2), int((size - cols) / 2))``, i.e. any odd
    leftover pixel goes to the bottom/right margin.
    """
    height, width = image.shape[0], image.shape[1]
    top = int((size - height) / 2)
    left = int((size - width) / 2)
    canvas = np.zeros((size, size))
    canvas[top:top + height, left:left + width] = image
    return canvas
|
|
57 |
|
|
|
58 |
def clip_and_normalize_SUVimage(img):
    """Clip values to [0, 30] and standardise them with fixed statistics.

    After clipping, the constant 2.13 is subtracted and the result is divided
    by 3.39. The three constants are hard-coded in this function.
    """
    lower, upper = 0., 30.00
    mean, spread = 2.13, 3.39
    clipped = np.clip(img, lower, upper)
    return (clipped - mean) / spread
|
|
64 |
|
|
|
65 |
def get_image(df, transform, norm):
    """Load one image from disk and return it as a randomly padded tensor.

    The file is read as flat float32 data, reshaped to
    ``(df.matrix_size_1, df.matrix_size_2)``, cropped with ``get_bbox`` and
    placed at a random offset in an ``image_size`` x ``image_size`` square.

    Args:
        df: Row-like object exposing ``filename``, ``matrix_size_1`` and
            ``matrix_size_2`` attributes. ``filename`` is used as a glob
            pattern relative to ``data_path_img``.
        transform: Optional callable applied to each slice along the first
            axis of the ``(1, image_size, image_size)`` tensor.
        norm: If truthy, apply ``clip_and_normalize_SUVimage`` after padding.

    Returns:
        A float tensor; ``(1, image_size, image_size)`` when ``transform`` is
        ``None``, otherwise the stacked outputs of ``transform``.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    pattern = os.path.join(data_path_img, df.filename)
    matches = glob.glob(pattern)
    if not matches:
        # Fail with the searched pattern instead of an opaque IndexError.
        raise FileNotFoundError('File not found: ' + pattern)
    # glob already returns paths that include data_path_img; joining the
    # directory again would duplicate a relative directory prefix.
    img = np.fromfile(matches[0], dtype='float32')
    img = np.reshape(img, [df.matrix_size_1, df.matrix_size_2])
    # Crop to the non-zero content
    img = get_bbox(img)
    # Pad to a square at a random offset
    img = pad2square_random(img, image_size)
    if norm:
        img = clip_and_normalize_SUVimage(img)
    # Add a leading axis of length 1
    img = torch.FloatTensor(img).unsqueeze(0)
    if transform is not None:
        img = torch.stack([transform(x) for x in img])
    return img
|
|
84 |
|
|
|
85 |
def get_image_center(df, transform, norm):
    """Load one image from disk and return it as a centre-padded tensor.

    The file is read as flat float32 data, reshaped to
    ``(df.matrix_size_1, df.matrix_size_2)``, cropped with ``get_bbox`` and
    centred in an ``image_size`` x ``image_size`` square.

    Args:
        df: Row-like object exposing ``filename``, ``matrix_size_1`` and
            ``matrix_size_2`` attributes. ``filename`` is used as a glob
            pattern relative to ``data_path_img``.
        transform: Optional callable applied to each slice along the first
            axis of the ``(1, image_size, image_size)`` tensor.
        norm: If truthy, apply ``clip_and_normalize_SUVimage`` after padding.

    Returns:
        A float tensor; ``(1, image_size, image_size)`` when ``transform`` is
        ``None``, otherwise the stacked outputs of ``transform``.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    pattern = os.path.join(data_path_img, df.filename)
    matches = glob.glob(pattern)
    if not matches:
        # Fail with the searched pattern instead of an opaque IndexError.
        raise FileNotFoundError('File not found: ' + pattern)
    # glob already returns paths that include data_path_img; joining the
    # directory again would duplicate a relative directory prefix.
    img = np.fromfile(matches[0], dtype='float32')
    img = np.reshape(img, [df.matrix_size_1, df.matrix_size_2])
    # Crop to the non-zero content
    img = get_bbox(img)
    # Pad to a square with the image centred (deterministic)
    img = pad2square_center(img, image_size)
    if norm:
        img = clip_and_normalize_SUVimage(img)
    # Add a leading axis of length 1
    img = torch.FloatTensor(img).unsqueeze(0)
    if transform is not None:
        img = torch.stack([transform(x) for x in img])
    return img
|
|
104 |
|
|
|
105 |
class dataset_singleview(torch.utils.data.Dataset):
    """Dataset yielding ``(image, target)`` pairs loaded through ``get_image``.

    The constructor stores a copy of the given frame together with the
    ``transform`` and ``norm`` settings that are passed on to ``get_image``.
    """

    def __init__(self, df, transform=None, norm=False):
        self.transform = transform
        self.norm = norm
        # Copy so that later changes to the caller's frame do not leak in.
        self.df = df.copy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        return get_image(row, self.transform, self.norm), row.target

    def errors(self, probs):
        """Score predicted probabilities against the stored targets.

        Predictions are thresholded at 0.5.

        Returns:
            ``(auc, ber, fpr, fnr)`` where ``ber`` is the mean of the false
            positive rate and the false negative rate.
        """
        scored = self.df.copy()
        scored['p'] = probs
        scored['pred'] = (scored.p >= 0.5).astype(int)
        wrong = scored.pred != scored.target
        negatives = scored.target == 0
        positives = scored.target == 1
        fpr = (wrong & negatives).sum() / negatives.sum()
        fnr = (wrong & positives).sum() / positives.sum()
        ber = (fpr + fnr) / 2.
        auc = roc_auc_score(scored.target, scored.p)
        return auc, ber, fpr, fnr
|
|
130 |
|
|
|
131 |
class dataset_singleview_center(torch.utils.data.Dataset):
    """Dataset yielding ``(image, target)`` pairs loaded through ``get_image_center``.

    The constructor stores a copy of the given frame together with the
    ``transform`` and ``norm`` settings that are passed on to
    ``get_image_center``.
    """

    def __init__(self, df, transform=None, norm=False):
        self.transform = transform
        self.norm = norm
        # Copy so that later changes to the caller's frame do not leak in.
        self.df = df.copy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        return get_image_center(row, self.transform, self.norm), row.target

    def errors(self, probs):
        """Score predicted probabilities against the stored targets.

        Predictions are thresholded at 0.5.

        Returns:
            ``(auc, ber, fpr, fnr)`` where ``ber`` is the mean of the false
            positive rate and the false negative rate.
        """
        scored = self.df.copy()
        scored['p'] = probs
        scored['pred'] = (scored.p >= 0.5).astype(int)
        wrong = scored.pred != scored.target
        negatives = scored.target == 0
        positives = scored.target == 1
        fpr = (wrong & negatives).sum() / negatives.sum()
        fnr = (wrong & positives).sum() / positives.sum()
        ber = (fpr + fnr) / 2.
        auc = roc_auc_score(scored.target, scored.p)
        return auc, ber, fpr, fnr
|
|
156 |
|
|
|
157 |
class RandomFlip(object):
    """Randomly leave a 2D image unchanged or reverse it along one axis.
    """
    def __call__(self, image):
        # Three equally likely outcomes: no flip, reverse axis 0, reverse axis 1.
        axis = random.choice((None, 0, 1))
        if axis is None:
            return image
        reversed_index = range(image.shape[axis] - 1, -1, -1)
        if axis == 0:
            return image[reversed_index, :]
        return image[:, reversed_index]
|
|
169 |
|
|
|
170 |
class RandomFlipLeftRight(object):
    """Reverse the image along axis 1 with probability one half.
    """
    def __call__(self, image):
        # Two equally likely outcomes: keep the image or reverse axis 1.
        keep = random.choice((None, 1)) is None
        if keep:
            return image
        last = image.shape[1] - 1
        return image[:, range(last, -1, -1)]
|
|
179 |
|
|
|
180 |
class RandomRot90(object):
    """Rotate the image by a random multiple of 90 degrees in the (0, 1) plane.
    """
    def __call__(self, image):
        # 0..3 quarter turns, each equally likely; 0 returns the input as is.
        quarter_turns = random.randint(0, 3)
        if quarter_turns == 0:
            return image
        return torch.rot90(image, quarter_turns, (0, 1))
|
|
189 |
|
|
|
190 |
class RandomScale(object):
    """Multiply the image values by a random factor drawn from [0.85, 1.15).
    """
    def __call__(self, image):
        # A single uniform draw; the factor applies to every element alike.
        (factor,) = np.random.uniform(low=0.85, high=1.15, size=1)
        return image * factor
|
|
197 |
|
|
|
198 |
class RandomNoise(object):
    """Add signal-dependent Gaussian noise to the image with probability one half.

    When noise is applied, negative values are first set to zero, then
    zero-mean Gaussian noise with per-element standard deviation
    ``sigma * image + level`` is added (``sigma`` drawn from [0.01, 0.1),
    ``level`` from [0.001, 0.02)), and negative results are set to zero again.
    The input tensor itself is never modified; when no noise is applied it is
    returned as is.
    """
    def __call__(self, image):
        if random.choice((None, 1)) is None:
            return image
        # Work on a copy so the caller's tensor is not clamped in place.
        image = image.clone()
        image[image<0] = 0
        # Draw order (level, then sigma, then the normal sample) is kept so
        # that seeded runs consume the random streams as before.
        level = np.random.uniform(low=0.001, high=0.02, size=1)
        sigma = np.random.uniform(low=0.01, high=0.1, size=1)
        sigma = sigma[0]*image+level[0]
        gauss = torch.normal(0,sigma)
        # The sum is a fresh tensor, so clamping it in place is safe.
        image = image + gauss
        image[image<0] = 0
        return image