# lungs/main.py
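# Keep TensorFlow quiet (FutureWarnings, C++ log level, deprecation messages) unless VERBOSE_TF is set.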
VERBOSE_TF = False
import os
if not VERBOSE_TF:
    import warnings
    warnings.filterwarnings('ignore', category=FutureWarning)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
else:
    import tensorflow as tf
from random import shuffle
import numpy as np
import argparse
from time import time, strftime
from tqdm import tqdm
from os.path import join, dirname, realpath
from collections import defaultdict
from sklearn.metrics import roc_auc_score
from datetime import date
from pathlib import Path

from lungs.preprocess import preprocess, walk_dicom_dirs, walk_np_files
from lungs import utils
from lungs.i3d import InceptionI3d

class I3dForCTVolumes:
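    '''Wrapper of the I3D model for cancer risk prediction on CT volumes.

    Builds the TF graph (placeholders, I3D network, loss, optimizer, metrics and savers)
    at construction time, and owns the tf.Session used for training and inference.
    '''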
    def __init__(self, args):
        self.args = args

        # This is the shape of both dimensions of each slice of the volume.
        # The final volume shape fed to the model is [self.args['num_slices'], 224, 224]
        self.slice_size = 224

        # pylint: disable=not-context-manager
        with tf.Graph().as_default():
            global_step = tf.get_variable(
                'global_step',
                [],
                initializer=tf.constant_initializer(0),
                trainable=False
            )

            # Placeholders
            self.volumes_placeholder, self.labels_placeholder, self.is_training_placeholder = utils.placeholder_inputs(
                num_slices=self.args['num_slices'],
                crop_size=self.slice_size,
                rgb_channels=3
            )

            # Learning rate and optimizer
            lr = tf.train.exponential_decay(self.args['lr'], global_step, decay_steps=5000, decay_rate=0.1, staircase=True)
            optimizer = tf.train.AdamOptimizer(lr)

            # Init I3D model
            with tf.device('/device:' + self.args['device'] + ':0'):
                with tf.compat.v1.variable_scope('RGB'):
                    _, end_points = InceptionI3d(num_classes=2, final_endpoint='Predictions')\
                        (self.volumes_placeholder, self.is_training_placeholder, dropout_keep_prob=args['keep_prob'])
                self.logits = end_points['Logits']
                self.preds = end_points['Predictions']

                # Loss function
                # self.loss = utils.focal_loss(self.logits[:, 1], self.labels_placeholder)
                self.loss = utils.cross_entropy_loss(self.logits, self.labels_placeholder)

                # Evaluation metrics
                self.get_preds = utils.get_preds(self.preds)
                self.get_logits = utils.get_logits(self.logits)
                self.accuracy = utils.accuracy(self.logits, self.labels_placeholder)

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    grads = optimizer.compute_gradients(self.loss)
                    apply_gradient = optimizer.apply_gradients(grads, global_step=global_step)
                    self.train_op = tf.group(apply_gradient)

            # Create a saver for loading pretrained checkpoints.
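            # Only variables under the 'RGB' scope are restored; Adam slot variables and the
            # final 'Logits' layer are skipped so the classification head is trained from scratch.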
            pretrained_variable_map = {}
            for variable in tf.global_variables():
                if variable.name.split('/')[0] == 'RGB' and 'Adam' not in variable.name.split('/')[-1] \
                    and variable.name.split('/')[2] != 'Logits':
                    pretrained_variable_map[variable.name.replace(':0', '')] = variable
            self.pretrained_saver = tf.train.Saver(var_list=pretrained_variable_map, reshape=True)

            # Create a saver for writing training checkpoints.
            self.saver = tf.train.Saver()

            # Init local and global vars
            init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

            # Create a session for running Ops on the Graph.
            run_config = tf.ConfigProto(allow_soft_placement=True)
            self.sess = tf.Session(config=run_config)
            self.sess.run(init)

    def train_loop(self, train_list, metrics_dir):
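        '''Run one epoch of training over `train_list`, then evaluate on it and write its true labels to `metrics_dir`.'''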
        train_batches = utils.batcher(train_list, self.args['batch_size'])
        for coupled_batch in tqdm(train_batches):
            feed_dict, _ = self.process_data_into_dict(coupled_batch, is_training=True)
            self.sess.run(self.train_op, feed_dict=feed_dict)

        metrics = self.evaluate(train_list, ds='Train')
        utils.write_number_list(metrics[-1], join(metrics_dir, 'tr_true'), verbose=self.args['verbose'])
        return metrics

    def evaluate(self, coupled_list, ds='Val.'):
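        '''Evaluate the model on `coupled_list`; returns (mean_loss, mean_acc, auc_score, preds_list, labels_list).'''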
        coupled_batches = utils.batcher(coupled_list, self.args['batch_size'])

        loss_list, acc_list, preds_list, labels_list = [], [], [], []

        print('\nINFO: ++++++++++++++++++++ {} Evaluation ++++++++++++++++++++'.format(ds))
        for coupled_batch in tqdm(coupled_batches):
            feed_dict, labels = self.process_data_into_dict(coupled_batch)
            acc, loss, preds = self.sess.run([self.accuracy, self.loss, self.get_preds], feed_dict=feed_dict)
            loss_list.append(loss)
            acc_list.append(acc)
            preds_list.extend(preds)
            labels_list.extend(labels)

        if self.args['verbose']:
            print('\nDEBUG: {}. Preds/Labels: {}'.format(ds, list(zip(preds_list, labels_list))))
            print('\nDEBUG: {} Batch accuracy/loss: {}'.format(ds, list(zip(acc_list, loss_list))))

        mean_acc = np.mean(acc_list)
        mean_loss = np.mean(loss_list)
        auc_score = roc_auc_score(labels_list, preds_list)
        print('\n' + '=' * 34)
        print("|| INFO: {} Accuracy: {:.4f} ||".format(ds, mean_acc))
        print("|| INFO: {} Loss: {:.4f} ||".format(ds, mean_loss))
        print("|| INFO: {} AUC: {:.4f} ||".format(ds, auc_score))
        print('=' * 34)
        return mean_loss, mean_acc, auc_score, preds_list, labels_list

    def predict(self, inference_data):
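        '''Run inference on each volume under `inference_data`, preprocessing raw DICOM dirs
        on the fly unless the inputs are already preprocessed .npz/.npy files.'''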
        errors_map = defaultdict(int)
        volume_iterator = walk_np_files(inference_data) if self.args['preprocessed'] else walk_dicom_dirs(inference_data)

        for i, volume_path in enumerate(volume_iterator):
            try:
                if not self.args['preprocessed']:
                    print('\nINFO: Preprocessing volume...')
                    preprocessed, _ = preprocess(volume_path, errors_map, self.args['num_slices'], self.slice_size, \
                        sample_volume=False, verbose=self.args['verbose'])
                else:
                    preprocessed = self.load_np_volume(volume_path)
                    # preprocessed = np.expand_dims(preprocessed, axis=0)
            except ValueError as e:
                raise e

            print('\nINFO: Predicting cancer for volume no. {}...'.format(i + 1))
            singleton_batch = [[preprocessed, None]]
            feed_dict, _ = self.process_data_into_dict(singleton_batch, from_paths=False)
            preds = self.sess.run([self.get_preds], feed_dict=feed_dict)
            print('\nINFO: Probability of cancer within 1 year: {:.5f}\n\n'.format(preds[0][0]))

    def process_data_into_dict(self, coupled_batch, from_paths=True, is_training=False):
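        '''Build a feed_dict from a batch of (volume, label) pairs: load each volume (optionally
        from a path), center-crop it to num_slices, apply windowing, and pair it with its label
        (zeros are used as dummy labels for unlabeled inference batches).'''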
        volumes = []
        labels = []
        for volume, label in coupled_batch:
            try:
                if from_paths:
                    volume = self.load_np_volume(volume)

                # Crop volume to shape (self.args['num_slices'], 224, 224)
                crop_start = volume.shape[0] // 2 - self.args['num_slices'] // 2
                volume = volume[crop_start: crop_start + self.args['num_slices']]
                volumes.append(volume)

                if label is not None:
                    labels.append(label)
            except Exception:
                print('\nERROR! Could not load:', volume)

        # Apply windowing to the volumes on the fly, to save storage space for preprocessed volumes
        volumes = np.array(volumes)
        volume_batch = utils.apply_window(volumes)

        if labels:
            labels_np = np.array(labels).astype(np.int64)
        else:
            labels_np = np.zeros(volume_batch.shape[0], dtype=np.int64)

        feed_dict = {self.volumes_placeholder: volume_batch, self.labels_placeholder: labels_np, self.is_training_placeholder: is_training}
        return feed_dict, labels

    def load_np_volume(self, volume_file):
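        '''Load a single volume from data_dir: a .npz archive (under the 'data' key) or a raw .npy file.'''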
        if volume_file.endswith('.npz'):
            scan_arr = np.load(join(self.args['data_dir'], volume_file))['data']
        else:
            scan_arr = np.load(join(self.args['data_dir'], volume_file)).astype(np.float32)
        return scan_arr

def create_output_dirs(args):
    # Create the model dir and log dirs if they don't exist
    timestamp = date.today().strftime("%A_") + strftime("%H:%M:%S")
    out_dir_time = Path(str(args['out_dir']) + '_' + timestamp)
    save_dir = out_dir_time / 'models'
    metrics_dir = out_dir_time / 'metrics'
    val_preds_dir = metrics_dir / 'val_preds'
    tr_preds_dir = metrics_dir / 'tr_preds'
    plots_dir = out_dir_time / 'plots'

    for new_dir in out_dir_time, save_dir, val_preds_dir, tr_preds_dir, plots_dir:
        os.makedirs(new_dir, exist_ok=True)

    return save_dir, metrics_dir, plots_dir

def main(args):
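    '''Build the model, restore pretrained weights, then run inference (if --input is given) or the train/val loop.'''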
    print('\nINFO: Initializing...')

    # Set GPU
    if args['device'] == 'GPU':
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args['gpu_id'])

    # Init model wrapper
    model = I3dForCTVolumes(args)

    # Load pre-trained weights
    pre_trained_ckpt = utils.load_pretrained_ckpt(args['ckpt'], args['data_dir'])
    model.pretrained_saver.restore(model.sess, pre_trained_ckpt)

    if args['input']:
        print('\nINFO: Begin Inference \n')
        model.predict(args['input'])
    else:
        print('\nINFO: Begin Training')

        print('\nINFO: Hyperparams:')
        print('\n'.join([str(item) for item in args.items()]))

        save_dir, metrics_dir, plots_dir = create_output_dirs(args)

        train_list = utils.load_data_list(args['train'])
        val_list = utils.load_data_list(args['val'])
        val_labels = utils.get_list_labels(val_list)
        utils.write_number_list(val_labels, join(metrics_dir, 'val_true'), verbose=args['verbose'])

        metrics = defaultdict(list)
        for epoch in range(1, args['epochs'] + 1):
            print('\nINFO: +++++++++++++++++++++ EPOCH {} +++++++++++++++++++++'.format(epoch))
            start_time = time()
            shuffle(train_list)

            # Run training for 1 epoch and save weights to file
            tr_epoch_metrics = model.train_loop(train_list, metrics_dir)
            print("\nINFO: Saving Weights...")
            model.saver.save(model.sess, "{}/epoch_{}/model.ckpt".format(save_dir, epoch))

            train_end_time = time()
            print('\nINFO: Train epoch duration: {:.2f} secs'.format(train_end_time - start_time))

            # Run validation at the end of each epoch
            print("\nINFO: Begin Validation")
            val_metrics = model.evaluate(val_list)

            print('\nINFO: Val duration: {:.2f} secs'.format(time() - train_end_time))

            print('\nINFO: Writing metrics and plotting them...')
            utils.write_metrics(metrics, tr_epoch_metrics, val_metrics, metrics_dir, epoch, verbose=args['verbose'])
            utils.plot_metrics(epoch, metrics_dir, plots_dir)

def train(**kwargs):
    '''
    Run training.
    For arguments description, see General and Training sections in params() function below.
    '''
    final_kwargs = params()
    # Override default parameters with given arguments
    for key, value in kwargs.items():
        final_kwargs[key] = value
    main(final_kwargs)

def predict(**kwargs):
    '''
    Run prediction.
    For arguments description, see General and Inference sections in params() function below.
    '''
    final_kwargs = params()
    # Override default parameters with given arguments
    for key, value in kwargs.items():
        final_kwargs[key] = value
    main(final_kwargs)
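
# Example programmatic usage (illustrative; paths and values are placeholders):
#   from lungs import main
#   main.train(epochs=10, gpu_id=0)
#   main.predict(input='/path/to/dicom_dirs')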

def params():
    parser = argparse.ArgumentParser()

    default_out_dir = Path.home() / 'Lung-Cancer-Risk-Prediction' / 'out'
    default_data_dir = Path.home() / 'Lung-Cancer-Risk-Prediction' / 'data'
    lists_dir = default_data_dir / 'lists'

    ######################################## General parameters #########################################
    parser.add_argument('--ckpt', default='cancer_fine_tuned', type=str, help="pre-trained weights to load. \
        Either 'i3d_imagenet', 'cancer_fine_tuned' or a path to a directory containing a model.ckpt file")

    parser.add_argument('--num_slices', default=220, type=int, \
        help='number of slices (z dimension) from the volume to be used by the model')

    parser.add_argument('--verbose', action='store_true', help='whether to print detailed logs')
    ######################################## Training parameters ########################################
    parser.add_argument('--epochs', default=40, type=int, help='the number of epochs')

    parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')

    parser.add_argument('--keep_prob', default=0.8, type=float, help='dropout keep prob')

    parser.add_argument('--batch_size', default=2, type=int, help='the batch size for training/validation')

    parser.add_argument('--gpu_id', default=1, type=int, help='gpu id')

    parser.add_argument('--device', default='GPU', type=str, help='the device to execute on')

    parser.add_argument('--data_dir', default=default_data_dir, \
        help='path to data directory (for raw/processed volumes, train/val lists, checkpoints etc.)')

    parser.add_argument('--train', default=lists_dir / 'train.list', help='path to train data .list file')

    parser.add_argument('--val', default=lists_dir / 'val.list', help='path to validation data .list file')

    parser.add_argument('--out_dir', default=default_out_dir, help='path to output dir for models, metrics and plots')

    ######################################## Inference parameters ########################################
    parser.add_argument('--input', default=None, type=str, help="path to directory of volumes for cancer prediction.")

    parser.add_argument('--preprocessed', action='store_true', help='whether data for inference is \
        preprocessed (.npz files) or raw volumes (dirs of .dcm files)')

    parser.set_defaults()
    args, _ = parser.parse_known_args()
    kwargs = vars(args)
    return kwargs

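# Example CLI usage (illustrative; assumes the lungs package is importable and paths are placeholders):
#   python -m lungs.main --epochs 10 --gpu_id 0
#   python -m lungs.main --input /path/to/npz_volumes --preprocessed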

if __name__ == "__main__":
    kwargs = params()
    main(kwargs)