lungs/utils.py
# pylint: disable=missing-docstring
import os
from os.path import join

import gdown
import matplotlib.pyplot as plt
import numpy as np
import scikitplot as skplt
import seaborn as sns
import tensorflow as tf
from matplotlib.pyplot import figure
from sklearn.metrics import roc_auc_score, roc_curve

sns.set_style("darkgrid")

# Pre-trained checkpoints hosted on Google Drive: 'url' holds the Drive file ID
# (appended to a 'uc?id=' URL in download_ckpt); 'md5' is the archive checksum.
REMOTE_CKPTS = {
    'cancer_fine_tuned': {'url': '1Zc8KdEz9JUfkT1ZsG9ELYReUPbVapbQC', 'md5': 'cd5271617e090859f73a727da81cc2e3'},
    'i3d_imagenet': {'url': '1FMWHGFYPjuvpgzkGm-_gYKdXpmv5fOq2', 'md5': 'f1408b50e5871153516fe87932121745'}
}

def load_pretrained_ckpt(ckpt, data_dir):
    if ckpt in REMOTE_CKPTS:
        download_ckpt(data_dir, ckpt, REMOTE_CKPTS[ckpt])

    # Resolve either a pre-defined checkpoint name or an explicit checkpoint path.
    predefined = join(data_dir, 'checkpoints', ckpt)
    ckpt_dir = predefined if os.path.exists(predefined) else ckpt

    pre_trained_ckpt = join(ckpt_dir, 'model.ckpt')
    print('\nINFO: Loading pre-trained model:', pre_trained_ckpt)
    return pre_trained_ckpt
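
# Usage sketch ('i3d_imagenet' is a REMOTE_CKPTS key; the data dir is hypothetical):
#   ckpt_path = load_pretrained_ckpt('i3d_imagenet', data_dir='data')
#   # Downloads and extracts the archive on first use, then returns
#   # 'data/checkpoints/i3d_imagenet/model.ckpt'.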

def download_ckpt(data_dir, name, download_info):
    if os.path.exists(join(data_dir, 'checkpoints', name)):
        print('\nINFO: {} model already downloaded.'.format(name))
    else:
        print('\nINFO: Downloading {} model...'.format(name))
        url = 'https://drive.google.com/uc?id=' + download_info['url']
        zip_output = join(data_dir, 'checkpoints', name + '.zip')
        md5 = download_info['md5']
        # Download once, verify the MD5, extract in place, then drop the archive.
        gdown.cached_download(url, zip_output, md5=md5, postprocess=gdown.extractall, quiet=True)
        os.remove(zip_output)

def pretty_print_floats(lst):
    return ', '.join('{:.3f}'.format(x) for x in lst)
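
# e.g. pretty_print_floats([0.1234, 0.5]) -> '0.123, 0.500'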

def load_npz_as_list(base_dir, npz_file):
    # np.savez stores a single unnamed array under the key 'arr_0'.
    return np.load(join(base_dir, npz_file))['arr_0'].tolist()

def plot_loss(val_loss, tr_loss, plots_dir):
    figure(num=None, figsize=(16, 8), dpi=100)
    title = 'Training and Validation Loss'
    epochs = range(1, len(val_loss) + 1)
    plt.plot(epochs, val_loss, label='Val. Loss')
    plt.plot(epochs, tr_loss, label='Train Loss')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # Save before show(): show() may clear the figure, leaving an empty image file.
    plt.savefig(join(plots_dir, title + '.png'), bbox_inches='tight')
    plt.show()

def plot_acc_auc(val_acc, tr_acc, val_auc, tr_auc, plots_dir):
    figure(num=None, figsize=(16, 8), dpi=100)
    title = 'Accuracy and AUC'
    epochs = range(1, len(val_acc) + 1)
    plt.plot(epochs, val_acc, label='Val. Accuracy')
    plt.plot(epochs, tr_acc, label='Train Accuracy')
    plt.plot(epochs, val_auc, label='Val. AUC')
    plt.plot(epochs, tr_auc, label='Train AUC')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.savefig(join(plots_dir, title + '.png'), bbox_inches='tight')

def calc_plot_epoch_auc_roc(y, y_probs, title, plots_dir, verbose=False):
    # plot_roc expects per-class probabilities, so expand the positive-class
    # probability p into the pair (1 - p, p).
    y_prob_2_classes = [(1 - p, p) for p in y_probs]
    fpr, tpr, th = roc_curve(y, y_probs)
    if verbose:
        print('TPR:', pretty_print_floats(tpr))
        print('FPR:', pretty_print_floats(fpr))
        print('TH: ', pretty_print_floats(th), '\n')
    auc = roc_auc_score(y, y_probs)
    title = title + ', AUC={:.3f}'.format(auc)
    skplt.metrics.plot_roc(y, y_prob_2_classes, classes_to_plot=[],
                           title=title,
                           figsize=(7, 7), plot_micro=False, plot_macro=True,
                           title_fontsize=15, text_fontsize=13)
    # Save before show(): show() may clear the figure, leaving an empty image file.
    plt.savefig(join(plots_dir, title) + '.png', bbox_inches='tight')
    plt.show()

def load_and_plot_epoch_auc(metrics_dir, epoch, val_true, tr_true, plots_dir):
    val_preds_epoch = load_npz_as_list(metrics_dir, 'val_preds/epoch_' + str(epoch) + '.npz')
    calc_plot_epoch_auc_roc(val_true, val_preds_epoch,
                            'Val. ROC for Epoch {}'.format(epoch), plots_dir)

    tr_preds_epoch = load_npz_as_list(metrics_dir, 'tr_preds/epoch_' + str(epoch) + '.npz')
    calc_plot_epoch_auc_roc(tr_true, tr_preds_epoch,
                            'Train ROC for Epoch {}'.format(epoch), plots_dir)

def plot_metrics(epoch, metrics_dir, plots_dir):
    val_loss = load_npz_as_list(metrics_dir, 'val_loss.npz')
    val_acc = load_npz_as_list(metrics_dir, 'val_acc.npz')
    val_auc = load_npz_as_list(metrics_dir, 'val_auc.npz')
    val_true = load_npz_as_list(metrics_dir, 'val_true.npz')

    tr_loss = load_npz_as_list(metrics_dir, 'tr_loss.npz')
    tr_acc = load_npz_as_list(metrics_dir, 'tr_acc.npz')
    tr_auc = load_npz_as_list(metrics_dir, 'tr_auc.npz')
    tr_true = load_npz_as_list(metrics_dir, 'tr_true.npz')

    plot_loss(val_loss, tr_loss, plots_dir)
    plot_acc_auc(val_acc, tr_acc, val_auc, tr_auc, plots_dir)
    load_and_plot_epoch_auc(metrics_dir, epoch, val_true, tr_true, plots_dir)
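
# Assumed metrics_dir layout, inferred from the loads above (the histories and
# per-epoch predictions are written by write_metrics below; the *_true.npz label
# files are assumed to be saved elsewhere in the pipeline):
#   {tr,val}_{loss,acc,auc}.npz     one value per epoch
#   {tr,val}_true.npz               ground-truth labels
#   {tr,val}_preds/epoch_<n>.npz    predictions for epoch n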

def write_metrics(metrics, tr_metrics, val_metrics, metrics_dir, epoch, verbose=False):
    # tr_metrics and val_metrics are (loss, acc, auc, preds, _) tuples for this epoch.
    for (loss, acc, auc, preds, _), ds in ((tr_metrics, 'tr'), (val_metrics, 'val')):
        for metric, key in [(loss, 'loss'), (acc, 'acc'), (auc, 'auc'), (preds, 'preds')]:
            name = ds + '_' + key
            metrics[name].append(metric)
            write_number_list(metrics[name], join(metrics_dir, name))
        write_number_list(preds, join(metrics_dir, ds + '_preds', 'epoch_{}'.format(epoch)), verbose=verbose)

def apply_window(volume, axis=4):
    # Windowing: raw CT values range from -1024 HU to around 2000 HU. Anything
    # above 400 HU is not interesting here, as those voxels are simply bones of
    # varying radiodensity. A commonly used pair of thresholds for lung LDCT is
    # therefore -1000 to 400 HU.
    min_bound = -1000.0
    max_bound = 400.0
    volume = (volume - min_bound) / (max_bound - min_bound)
    volume[volume > 1] = 1.0
    volume[volume < 0] = 0.0

    # Rescale to [-1, 1] and replicate the channel to get 3-channel (RGB-like) input.
    volume = (volume * 2) - 1
    res = np.stack((volume, volume, volume), axis=axis)
    return res.astype(np.float32)
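
# Worked per-voxel example of the two-step mapping above:
#   -1000 HU -> 0.0 after windowing -> -1.0 after rescaling
#    -300 HU -> 0.5                 ->  0.0
#     400 HU -> 1.0                 ->  1.0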

def write_number_list(lst, f_name, verbose=False):
    if verbose:
        print('INFO: Saving ' + f_name + '.npz ...')
        print(lst)
    np.savez(f_name + '.npz', np.array(lst))

def batcher(iterable, batch_size=1):
    # Yield successive batch_size-sized slices; the last batch may be smaller.
    iter_len = len(iterable)
    for i in range(0, iter_len, batch_size):
        yield iterable[i: min(i + batch_size, iter_len)]
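
# e.g. list(batcher([1, 2, 3, 4, 5], batch_size=2)) -> [[1, 2], [3, 4], [5]]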

def load_data_list(path):
    coupled_data = []
    with open(path) as file_list_fp:
        for line in file_list_fp:
            volume_path, label = line.split()
            coupled_data.append((volume_path, int(label)))
    return coupled_data
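
# Assumed list-file format: one whitespace-separated "<volume_path> <label>" pair
# per line, e.g. (hypothetical paths):
#   volumes/patient_001.npy 1
#   volumes/patient_002.npy 0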

def get_list_labels(coupled_data):
    return np.array([label for _, label in coupled_data]).astype(np.int64)

def placeholder_inputs(num_slices, crop_size, rgb_channels=3):
    """Generate placeholder variables to represent the input tensors.

    These placeholders are used as inputs by the rest of the model-building
    code and are fed from the loaded data in the session's .run() loop.

    Args:
        num_slices: The number of slices per volume.
        crop_size: The spatial crop size per volume.
        rgb_channels: The number of RGB input channels per volume.

    Returns:
        volumes_placeholder: Volumes placeholder.
        labels_placeholder: Labels placeholder.
        is_training: Boolean placeholder toggling train/eval behaviour.
    """
    # Note that the shapes of the placeholders match the shapes of the full
    # volume and label tensors, except the first dimension is now the batch
    # size rather than the full size of the train or test data sets.
    volumes_placeholder = tf.placeholder(tf.float32, shape=(None,
                                                            num_slices,
                                                            crop_size,
                                                            crop_size,
                                                            rgb_channels))
    labels_placeholder = tf.placeholder(tf.int64, shape=(None,))
    is_training = tf.placeholder(tf.bool)
    return volumes_placeholder, labels_placeholder, is_training
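
# Usage sketch (hypothetical sizes and session objects):
#   volumes_ph, labels_ph, is_training = placeholder_inputs(num_slices=140, crop_size=224)
#   sess.run(train_op, feed_dict={volumes_ph: batch_volumes,
#                                 labels_ph: batch_labels,
#                                 is_training: True})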

def focal_loss(logits, labels, alpha=0.75, gamma=2):
    """Compute focal loss for binary classification.

    Args:
        labels: An int32 or int64 tensor of shape [batch_size].
        logits: A float32 tensor of shape [batch_size].
        alpha: A scalar for the focal loss alpha hyper-parameter. If there are
            more positive than negative samples, set alpha < 0.5, and vice versa.
        gamma: A scalar for the focal loss gamma hyper-parameter.
    Returns:
        A scalar tensor: the mean focal loss over the batch.
    """
    y_pred = tf.nn.sigmoid(logits)
    labels = tf.to_float(labels)
    # FL = -(1 - alpha) * y * (1 - p)^gamma * log(p)
    #      - alpha * (1 - y) * p^gamma * log(1 - p)
    losses = -(labels * (1 - alpha) * ((1 - y_pred) ** gamma) * tf.log(y_pred)) - \
             (1 - labels) * alpha * (y_pred ** gamma) * tf.log(1 - y_pred)
    return tf.reduce_mean(losses)
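
# Worked example with alpha=0.75, gamma=2 (values rounded):
#   easy positive (y=1, p=0.9): 0.25 * 0.1**2 * -log(0.9) ~ 0.0003
#   hard negative (y=0, p=0.9): 0.75 * 0.9**2 * -log(0.1) ~ 1.40
# The modulating factors (1 - p)**gamma and p**gamma down-weight well-classified
# examples so training focuses on the hard ones.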

def cross_entropy_loss(logits, labels):
    # pylint: disable=no-value-for-parameter
    # pylint: disable=unexpected-keyword-arg
    cross_entropy_mean = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    )
    return cross_entropy_mean

def accuracy(logit, labels):
    correct_pred = tf.equal(tf.argmax(logit, 1), labels)
    return tf.reduce_mean(tf.cast(correct_pred, tf.float32))

def get_preds(preds):
    # Column 1 of the softmax output [batch_size, 2]: the positive-class probability.
    return preds[:, 1]

def get_logits(logits):
    # Identity helper, kept for symmetry with get_preds.
    return logits