# lungs/main.py
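'''Training, evaluation and inference entry points for the lungs I3D model.'''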
VERBOSE_TF = False
import os
if not VERBOSE_TF:
    import warnings
    warnings.filterwarnings('ignore', category=FutureWarning)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
else:
    import tensorflow as tf
from random import shuffle
import numpy as np
import argparse
from time import time, strftime
from tqdm import tqdm
from os.path import join, dirname, realpath
from collections import defaultdict
from sklearn.metrics import roc_auc_score
from datetime import date
from pathlib import Path

from lungs.preprocess import preprocess, walk_dicom_dirs, walk_np_files
from lungs import utils
from lungs.i3d import InceptionI3d

class I3dForCTVolumes:
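    '''
    TF1-style wrapper around InceptionI3d for lung CT volumes: builds the graph
    (placeholders, cross-entropy loss, Adam optimizer, checkpoint savers) and
    provides helpers for training, evaluation and inference.
    '''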
    def __init__(self, args):
        self.args = args

        # This is the size of both spatial dimensions of each slice of the volume.
        # The final volume shape fed to the model is [self.args['num_slices'], 224, 224].
        self.slice_size = 224

        # pylint: disable=not-context-manager
        with tf.Graph().as_default():
            global_step = tf.get_variable(
                    'global_step',
                    [],
                    initializer=tf.constant_initializer(0),
                    trainable=False
                    )

            # Placeholders
            self.volumes_placeholder, self.labels_placeholder, self.is_training_placeholder = utils.placeholder_inputs(
                    num_slices=self.args['num_slices'],
                    crop_size=self.slice_size,
                    rgb_channels=3
                    )

            # Learning rate and optimizer
            lr = tf.train.exponential_decay(self.args['lr'], global_step, decay_steps=5000, decay_rate=0.1, staircase=True)
            optimizer = tf.train.AdamOptimizer(lr)

            # Init I3D model
            with tf.device('/device:' + self.args['device'] + ':0'):
                with tf.compat.v1.variable_scope('RGB'):
                    _, end_points = InceptionI3d(num_classes=2, final_endpoint='Predictions')\
                        (self.volumes_placeholder, self.is_training_placeholder, dropout_keep_prob=args['keep_prob'])
                self.logits = end_points['Logits']
                self.preds = end_points['Predictions']

                # Loss function
                # self.loss = utils.focal_loss(self.logits[:, 1], self.labels_placeholder)
                self.loss = utils.cross_entropy_loss(self.logits, self.labels_placeholder)

                # Evaluation metrics
                self.get_preds = utils.get_preds(self.preds)
                self.get_logits = utils.get_logits(self.logits)
                self.accuracy = utils.accuracy(self.logits, self.labels_placeholder)

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    grads = optimizer.compute_gradients(self.loss)
                    apply_gradient = optimizer.apply_gradients(grads, global_step=global_step)
                    self.train_op = tf.group(apply_gradient)

            # Create a saver for loading pretrained checkpoints.
            pretrained_variable_map = {}
            for variable in tf.global_variables():
                if variable.name.split('/')[0] == 'RGB' and 'Adam' not in variable.name.split('/')[-1] \
                    and variable.name.split('/')[2] != 'Logits':
                    pretrained_variable_map[variable.name.replace(':0', '')] = variable
            self.pretrained_saver = tf.train.Saver(var_list=pretrained_variable_map, reshape=True)

            # Create a saver for writing training checkpoints.
            self.saver = tf.train.Saver()

            # Init local and global vars
            init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

            # Create a session for running Ops on the Graph.
            run_config = tf.ConfigProto(allow_soft_placement=True)
            self.sess = tf.Session(config=run_config)
            self.sess.run(init)

    def train_loop(self, train_list, metrics_dir):
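        '''Run one training epoch over train_list, then evaluate on the training set and return its metrics.'''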
        train_batches = utils.batcher(train_list, self.args['batch_size'])
        for coupled_batch in tqdm(train_batches):
            feed_dict, _ = self.process_data_into_dict(coupled_batch, is_training=True)
            self.sess.run(self.train_op, feed_dict=feed_dict)

        metrics = self.evaluate(train_list, ds='Train')
        utils.write_number_list(metrics[-1], join(metrics_dir, 'tr_true'), verbose=self.args['verbose'])
        return metrics

    def evaluate(self, coupled_list, ds='Val.'):
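        '''Return mean loss, mean accuracy, ROC AUC, predictions and labels of the model over coupled_list.'''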
        coupled_batches = utils.batcher(coupled_list, self.args['batch_size'])

        loss_list, acc_list, preds_list, labels_list = [], [], [], []

        print('\nINFO: ++++++++++++++++++++ {} Evaluation ++++++++++++++++++++'.format(ds))
        for coupled_batch in tqdm(coupled_batches):
            feed_dict, labels = self.process_data_into_dict(coupled_batch)
            acc, loss, preds = self.sess.run([self.accuracy, self.loss, self.get_preds], feed_dict=feed_dict)
            loss_list.append(loss)
            acc_list.append(acc)
            preds_list.extend(preds)
            labels_list.extend(labels)

        if self.args['verbose']:
            print('\nDEBUG: {}. Preds/Labels: {}'.format(ds, list(zip(preds_list, labels_list))))
            print('\nDEBUG: {} Batch accuracy/loss: {}'.format(ds, list(zip(acc_list, loss_list))))

        mean_acc = np.mean(acc_list)
        mean_loss = np.mean(loss_list)
        auc_score = roc_auc_score(labels_list, preds_list)
        print('\n' + '=' * 34)
        print("||  INFO: {} Accuracy: {:.4f} ||".format(ds, mean_acc))
        print("||  INFO: {} Loss:     {:.4f} ||".format(ds, mean_loss))
        print("||  INFO: {} AUC:      {:.4f} ||".format(ds, auc_score))
        print('=' * 34)
        return mean_loss, mean_acc, auc_score, preds_list, labels_list

    def predict(self, inference_data):
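        '''Run inference on each volume under inference_data and print its predicted cancer probability.'''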
        errors_map = defaultdict(int)
        volume_iterator = walk_np_files(inference_data) if self.args['preprocessed'] else walk_dicom_dirs(inference_data)

        for i, volume_path in enumerate(volume_iterator):
            try:
                if not self.args['preprocessed']:
                    print('\nINFO: Preprocessing volume...')
                    preprocessed, _ = preprocess(volume_path, errors_map, self.args['num_slices'], self.slice_size, \
                        sample_volume=False, verbose=self.args['verbose'])
                else:
                    preprocessed = self.load_np_volume(volume_path)
                    # preprocessed = np.expand_dims(preprocessed, axis=0)
            except ValueError as e:
                raise e

            print('\nINFO: Predicting cancer for volume no. {}...'.format(i + 1))
            singleton_batch = [[preprocessed, None]]
            feed_dict, _ = self.process_data_into_dict(singleton_batch, from_paths=False)
            preds = self.sess.run([self.get_preds], feed_dict=feed_dict)
            print('\nINFO: Probability of cancer within 1 year: {:.5f}\n\n'.format(preds[0][0]))

    def process_data_into_dict(self, coupled_batch, from_paths=True, is_training=False):
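        '''Load, crop and window a batch of (volume, label) pairs into a feed_dict for the session.'''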
        volumes = []
        labels = []
        for volume, label in coupled_batch:
            try:
                if from_paths:
                    volume = self.load_np_volume(volume)

                # Crop volume to shape (self.args['num_slices'], 224, 224)
                crop_start = volume.shape[0] // 2 - self.args['num_slices'] // 2
                volume = volume[crop_start: crop_start + self.args['num_slices']]
                volumes.append(volume)

                if label is not None:
                    labels.append(label)
            except Exception:
                print('\nERROR! Could not load:', volume)

        # Apply intensity windowing to the volumes at load time, to save storage space for preprocessed volumes
        volumes = np.array(volumes)
        volume_batch = utils.apply_window(volumes)

        if labels:
            labels_np = np.array(labels).astype(np.int64)
        else:
            labels_np = np.zeros(volume_batch.shape[0], dtype=np.int64)

        feed_dict = {self.volumes_placeholder: volume_batch, self.labels_placeholder: labels_np, self.is_training_placeholder: is_training}
        return feed_dict, labels

    def load_np_volume(self, volume_file):
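        '''Load a preprocessed volume from the data directory, either from a .npz archive or a raw NumPy file.'''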
        if volume_file.endswith('.npz'):
            scan_arr = np.load(join(self.args['data_dir'], volume_file))['data']
        else:
            scan_arr = np.load(join(self.args['data_dir'], volume_file)).astype(np.float32)
        return scan_arr

def create_output_dirs(args):
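    '''Create a timestamped output directory tree (models, metrics, plots) and return its key paths.'''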
    # Create the output directories if they don't exist
    timestamp = date.today().strftime("%A_") + strftime("%H:%M:%S")
    out_dir_time = Path(str(args['out_dir']) + '_' + timestamp)
    save_dir = out_dir_time / 'models'
    metrics_dir = out_dir_time / 'metrics'
    val_preds_dir = metrics_dir / 'val_preds'
    tr_preds_dir = metrics_dir / 'tr_preds'
    plots_dir = out_dir_time / 'plots'

    for new_dir in out_dir_time, save_dir, val_preds_dir, tr_preds_dir, plots_dir:
        os.makedirs(new_dir, exist_ok=True)

    return save_dir, metrics_dir, plots_dir

def main(args):
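    '''Build the model, restore pre-trained weights, then run inference if --input is given, otherwise train.'''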
    print('\nINFO: Initializing...')

    # Set GPU
    if args['device'] == 'GPU':
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args['gpu_id'])

    # Init model wrapper
    model = I3dForCTVolumes(args)

    # Load pre-trained weights
    pre_trained_ckpt = utils.load_pretrained_ckpt(args['ckpt'], args['data_dir'])
    model.pretrained_saver.restore(model.sess, pre_trained_ckpt)

    if args['input']:
        print('\nINFO: Begin Inference \n')
        model.predict(args['input'])
    else:
        print('\nINFO: Begin Training')

        print('\nINFO: Hyperparams:')
        print('\n'.join([str(item) for item in args.items()]))

        save_dir, metrics_dir, plots_dir = create_output_dirs(args)

        train_list = utils.load_data_list(args['train'])
        val_list = utils.load_data_list(args['val'])
        val_labels = utils.get_list_labels(val_list)
        utils.write_number_list(val_labels, join(metrics_dir, 'val_true'), verbose=args['verbose'])

        metrics = defaultdict(list)
        for epoch in range(1, args['epochs'] + 1):
            print('\nINFO: +++++++++++++++++++++ EPOCH {} +++++++++++++++++++++'.format(epoch))
            start_time = time()
            shuffle(train_list)

            # Run training for 1 epoch and save weights to file
            tr_epoch_metrics = model.train_loop(train_list, metrics_dir)
            print("\nINFO: Saving Weights...")
            model.saver.save(model.sess, "{}/epoch_{}/model.ckpt".format(save_dir, epoch))

            train_end_time = time()
            print('\nINFO: Train epoch duration: {:.2f} secs'.format(train_end_time - start_time))

            # Run validation at end of each epoch
            print("\nINFO: Begin Validation")
            val_metrics = model.evaluate(val_list)

            print('\nINFO: Val duration: {:.2f} secs'.format(time() - train_end_time))

            print('\nINFO: Writing metrics and plotting them...')
            utils.write_metrics(metrics, tr_epoch_metrics, val_metrics, metrics_dir, epoch, verbose=args['verbose'])
            utils.plot_metrics(epoch, metrics_dir, plots_dir)

def train(**kwargs):
    '''
    Run training.
    For arguments description, see the General and Training sections in the params() function below.
    '''
    final_kwargs = params()
    # Override default parameters with given arguments
    for key, value in kwargs.items():
        final_kwargs[key] = value
    main(final_kwargs)

def predict(**kwargs):
    '''
    Run prediction.
    For arguments description, see the General and Inference sections in the params() function below.
    '''
    final_kwargs = params()
    # Override default parameters with given arguments
    for key, value in kwargs.items():
        final_kwargs[key] = value
    main(final_kwargs)
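
# Example programmatic usage (illustrative only; the paths and values below are
# placeholders, and any key returned by params() can be overridden):
#
#   from lungs import main
#   main.train(epochs=10, gpu_id=0)
#   main.predict(input='/path/to/dicom_dirs')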

def params():
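    '''Parse command-line arguments (general, training and inference parameters) and return them as a dict.'''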
    parser = argparse.ArgumentParser()

    default_out_dir = Path.home() / 'Lung-Cancer-Risk-Prediction' / 'out'
    default_data_dir = Path.home() / 'Lung-Cancer-Risk-Prediction' / 'data'
    lists_dir = default_data_dir / 'lists'

    ########################################   General parameters #########################################
    parser.add_argument('--ckpt', default='cancer_fine_tuned', type=str, help="pre-trained weights to load. \
        Either 'i3d_imagenet', 'cancer_fine_tuned' or a path to a directory containing a model.ckpt file")

    parser.add_argument('--num_slices', default=220, type=int, \
        help='number of slices (z dimension) from the volume to be used by the model')

    parser.add_argument('--verbose', default=False, action='store_true', help='whether to print detailed logs')

    ########################################   Training parameters ########################################
    parser.add_argument('--epochs', default=40, type=int, help='the number of epochs')

    parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')

    parser.add_argument('--keep_prob', default=0.8, type=float, help='dropout keep prob')

    parser.add_argument('--batch_size', default=2, type=int, help='the batch size for training/validation')

    parser.add_argument('--gpu_id', default=1, type=int, help='gpu id')

    parser.add_argument('--device', default='GPU', type=str, help='the device to execute on')

    parser.add_argument('--data_dir', default=default_data_dir, \
        help='path to data directory (for raw/processed volumes, train/val lists, checkpoints etc.)')

    parser.add_argument('--train', default=lists_dir / 'train.list', help='path to train data .list file')

    parser.add_argument('--val', default=lists_dir / 'val.list', help='path to validation data .list file')

    parser.add_argument('--out_dir', default=default_out_dir, help='path to output dir for models, metrics and plots')

    ########################################   Inference parameters ########################################
    parser.add_argument('--input', default=None, type=str, help="path to directory of volumes for cancer prediction.")

    parser.add_argument('--preprocessed', default=False, action='store_true', help='whether data for inference is \
        preprocessed (.npz files) or raw volumes (dirs of .dcm files)')

    parser.set_defaults()
    args, _ = parser.parse_known_args()
    kwargs = vars(args)
    return kwargs

if __name__ == "__main__":
    kwargs = params()
    main(kwargs)