# Segmentation/train/train.py
1
import sys
2
import os
3
from glob import glob
4
import datetime
5
import tensorflow as tf
6
import numpy as np
7
from time import time
8
9
from Segmentation.train.utils import setup_gpu, LearningRateUpdate, Metric
10
from Segmentation.train.reshape import get_mid_slice, get_mid_vol
11
from Segmentation.train.validation import validate_best_model
12
from Segmentation.utils.data_loader import read_tfrecord_3d
13
from Segmentation.utils.visualise_utils import visualise_sample
14
from Segmentation.utils.losses import dice_loss, tversky_loss, iou_loss
15
from Segmentation.utils.losses import iou_loss_eval_3d, dice_coef_eval_3d
16
from Segmentation.utils.losses import dice_loss_weighted_3d, focal_tversky
17
from Segmentation.model.vnet import VNet
18
19
class Train:
    """Distributed custom training loop for 3D segmentation models.

    Runs train/validation epochs under a ``tf.distribute`` strategy, logs
    losses, metrics, learning rate and sample visualisations to
    TensorBoard, and checkpoints the weights with the lowest validation
    loss seen so far.
    """

    def __init__(self,
                 epochs,
                 batch_size,
                 enable_function,
                 model,
                 optimizer,
                 loss_func,
                 lr_manager,
                 predict_slice,
                 metrics,
                 tfrec_dir='./Data/tfrecords/',
                 log_dir="logs"):
        """
        Args:
            epochs: number of epochs to train for.
            batch_size: global batch size (across all replicas).
            enable_function: if True, wrap the per-step strategy runners in
                ``tf.function`` for graph execution.
            model: ``tf.keras.Model`` to train.
            optimizer: tf.keras optimizer instance.
            loss_func: callable ``(y_true, y_pred) -> scalar loss``.
            lr_manager: object exposing ``update_lr(epoch) -> learning rate``.
            predict_slice: forwarded to the visualisation helper; presumably
                selects slice- vs volume-style output — TODO confirm.
            metrics: metric spec dict, wrapped in the project's ``Metric``
                bookkeeper.
            tfrec_dir: root directory of the tfrecord datasets.
            log_dir: root directory for TensorBoard logs / checkpoints.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.enable_function = enable_function
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.lr_manager = lr_manager
        self.predict_slice = predict_slice
        # Wrap the raw metric spec in the project's Metric bookkeeper.
        self.metrics = Metric(metrics)
        self.tfrec_dir = tfrec_dir
        self.log_dir = log_dir

    def train_step(self,
                   x_train,
                   y_train,
                   visualise):
        """Run one optimisation step on a (possibly per-replica) batch.

        Returns ``(loss, predictions)`` when ``visualise`` is truthy, else
        ``(loss, None)`` so large prediction tensors are not returned
        unnecessarily.
        """
        with tf.GradientTape() as tape:
            predictions = self.model(x_train, training=True)
            loss = self.loss_func(y_train, predictions)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        self.metrics.store_metric(y_train, predictions, training=True)
        if visualise:
            return loss, predictions
        return loss, None

    def test_step(self,
                  x_test,
                  y_test,
                  visualise):
        """Run one evaluation step; mirrors train_step without the update."""
        predictions = self.model(x_test, training=False)
        loss = self.loss_func(y_test, predictions)
        self.metrics.store_metric(y_test, predictions, training=False)
        if visualise:
            return loss, predictions
        return loss, None

    def train_model_loop(self,
                         train_ds,
                         valid_ds,
                         strategy,
                         multi_class,
                         visual_save_freq=5,
                         debug=False,
                         num_to_visualise=0):
        """Trains 3D model with custom tf loop and MirrorStrategy.

        Args:
            train_ds / valid_ds: (distributed) tf.data datasets yielding
                ``(x, y)`` batches.
            strategy: the ``tf.distribute`` strategy the datasets were
                distributed with.
            multi_class: whether training is multi-class (affects log dirs
                and visualisation).
            visual_save_freq: forwarded to the epoch helpers (currently
                unused by them — kept for interface stability).
            debug: selects the ``/debug`` log subdirectory.
            num_to_visualise: number of leading batches per epoch to
                visualise to TensorBoard.

        Returns:
            The timestamped log directory containing summaries and the
            ``best_weights.tf`` checkpoint.
        """

        def run_train_strategy(x, y, visualise):
            # Sum the per-replica losses into one scalar.
            total_step_loss, pred = strategy.run(self.train_step, args=(x, y, visualise, ))
            return strategy.reduce(
                tf.distribute.ReduceOp.SUM, total_step_loss, axis=None), pred

        def run_test_strategy(x, y, visualise):
            total_step_loss, pred = strategy.run(self.test_step, args=(x, y, visualise, ))
            return strategy.reduce(
                tf.distribute.ReduceOp.SUM, total_step_loss, axis=None), pred

        # TODO(Joe): This needs to be rewritten so that it works with 2D as well
        def distributed_train_epoch(train_ds,
                                    epoch,
                                    strategy,
                                    num_to_visualise,
                                    multi_class,
                                    slice_writer,
                                    vol_writer,
                                    visual_save_freq,
                                    predict_slice):
            """Run one training epoch; returns the mean per-batch loss."""
            total_loss, num_train_batch = 0.0, 0.0
            is_training = True
            use_2d = False
            for x_train, y_train in train_ds:
                # Only the first `num_to_visualise` batches get visualised.
                visualise = (num_train_batch < num_to_visualise)
                loss, pred = run_train_strategy(x_train, y_train, visualise)
                loss /= strategy.num_replicas_in_sync
                total_loss += loss
                if visualise:
                    num_to_visualise = visualise_sample(x_train, y_train, pred,
                                                        num_to_visualise,
                                                        slice_writer, vol_writer,
                                                        use_2d, epoch, multi_class,
                                                        predict_slice, is_training)
                num_train_batch += 1
            return total_loss / num_train_batch

        def distributed_test_epoch(valid_ds,
                                   epoch,
                                   strategy,
                                   num_to_visualise,
                                   multi_class,
                                   slice_writer,
                                   vol_writer,
                                   visual_save_freq,
                                   predict_slice):
            """Run one validation epoch; returns the mean per-batch loss."""
            total_loss, num_test_batch = 0.0, 0.0
            is_training = False
            use_2d = False
            for x_valid, y_valid in valid_ds:
                visualise = (num_test_batch < num_to_visualise)
                loss, pred = run_test_strategy(x_valid, y_valid, visualise)
                loss /= strategy.num_replicas_in_sync
                total_loss += loss
                if visualise:
                    # BUG FIX: previously passed x_train/y_train here, which
                    # are undefined in this scope and raised NameError on the
                    # first visualised validation batch; use the validation
                    # batch instead.
                    num_to_visualise = visualise_sample(x_valid, y_valid, pred,
                                                        num_to_visualise,
                                                        slice_writer, vol_writer,
                                                        use_2d, epoch, multi_class,
                                                        predict_slice, is_training)
                num_test_batch += 1
            return total_loss / num_test_batch

        if self.enable_function:
            run_train_strategy = tf.function(run_train_strategy)
            run_test_strategy = tf.function(run_test_strategy)

        # TensorBoard layout: logs/<model>/<debug|test>/<multi|binary>/<date>/<time>/...
        name = "/" + self.model.name
        db = "/debug" if debug else "/test"
        mc = "/multi" if multi_class else "/binary"
        log_dir_now = self.log_dir + name + db + mc + datetime.datetime.now().strftime("/%Y%m%d/%H%M%S")
        train_summary_writer = tf.summary.create_file_writer(log_dir_now + '/train')
        test_summary_writer = tf.summary.create_file_writer(log_dir_now + '/val')
        test_min_summary_writer = tf.summary.create_file_writer(log_dir_now + '/val_min')
        train_img_slice_writer = tf.summary.create_file_writer(log_dir_now + '/train/img/slice')
        test_img_slice_writer = tf.summary.create_file_writer(log_dir_now + '/val/img/slice')
        train_img_vol_writer = tf.summary.create_file_writer(log_dir_now + '/train/img/vol')
        test_img_vol_writer = tf.summary.create_file_writer(log_dir_now + '/val/img/vol')
        lr_summary_writer = tf.summary.create_file_writer(log_dir_now + '/lr')

        self.metrics.add_metric_summary_writer(log_dir_now)

        best_loss = None
        for e in range(self.epochs):
            # Let the schedule adjust the learning rate once per epoch.
            self.optimizer.learning_rate = self.lr_manager.update_lr(e)

            et0 = time()

            train_loss = distributed_train_epoch(train_ds,
                                                 e,
                                                 strategy,
                                                 num_to_visualise,
                                                 multi_class,
                                                 train_img_slice_writer,
                                                 train_img_vol_writer,
                                                 visual_save_freq,
                                                 self.predict_slice)
            with train_summary_writer.as_default():
                tf.summary.scalar('epoch_loss', train_loss, step=e)

            test_loss = distributed_test_epoch(valid_ds,
                                               e,
                                               strategy,
                                               num_to_visualise,
                                               multi_class,
                                               test_img_slice_writer,
                                               test_img_vol_writer,
                                               visual_save_freq,
                                               self.predict_slice)
            with test_summary_writer.as_default():
                tf.summary.scalar('epoch_loss', test_loss, step=e)

            current_lr = self.optimizer.get_config()['learning_rate']
            with lr_summary_writer.as_default():
                tf.summary.scalar('epoch_lr', current_lr, step=e)

            self.metrics.record_metric_to_summary(e)
            metric_str = self.metrics.reset_metrics_get_str()
            print(f"Epoch {e+1}/{self.epochs} - {time() - et0:.0f}s - loss: {train_loss:.05f} - val_loss: {test_loss:.05f} - lr: {self.optimizer.get_config()['learning_rate']: .06f}" + metric_str)

            # Checkpoint whenever the validation loss improves (merged the
            # duplicated first-epoch / later-epoch branches).
            if best_loss is None or test_loss < best_loss:
                self.model.save_weights(os.path.join(log_dir_now, 'best_weights.tf'))
                best_loss = test_loss
            with test_min_summary_writer.as_default():
                tf.summary.scalar('epoch_loss', best_loss, step=e)
        return log_dir_now
def load_datasets(batch_size, buffer_size,
                  tfrec_dir='./Data/tfrecords/',
                  multi_class=False,
                  crop_size=144,
                  depth_crop_size=80,
                  aug=None,
                  predict_slice=False,
                  ):
    """Load the train and validation tfrecord datasets for 3D models.

    Args:
        batch_size: batch size passed to the tfrecord reader.
        buffer_size: shuffle buffer size passed to the reader.
        tfrec_dir: root directory containing ``train_3d/`` and ``valid_3d/``.
        multi_class: whether labels are multi-class.
        crop_size: in-plane crop size forwarded to the reader.
        depth_crop_size: depth crop size forwarded to the reader.
        aug: list of augmentations forwarded to the reader. Defaults to no
            augmentation. (Changed from a mutable ``[]`` default, which is
            shared across calls, to ``None``.)
        predict_slice: forwarded to the reader — presumably selects
            slice-level targets; verify against ``read_tfrecord_3d``.

    Returns:
        ``(train_ds, valid_ds)`` tuple of datasets.
    """
    args = {
        'batch_size': batch_size,
        'buffer_size': buffer_size,
        'multi_class': multi_class,
        'use_keras_fit': False,
        'crop_size': crop_size,
        'depth_crop_size': depth_crop_size,
        'aug': [] if aug is None else aug,
    }
    train_ds = read_tfrecord_3d(tfrecords_dir=os.path.join(tfrec_dir, 'train_3d/'),
                                is_training=True, predict_slice=predict_slice, **args)
    valid_ds = read_tfrecord_3d(tfrecords_dir=os.path.join(tfrec_dir, 'valid_3d/'),
                                is_training=False, predict_slice=predict_slice, **args)
    return train_ds, valid_ds
def build_model(num_channels, num_classes, name, **kwargs):
    """Construct the standard 3D VNet segmentation model.

    Args:
        num_channels: forwarded to the VNet constructor.
        num_classes: forwarded to the VNet constructor.
        name: model name, forwarded as a keyword.
        **kwargs: any extra keyword arguments accepted by VNet.

    Returns:
        A freshly constructed VNet instance.
    """
    return VNet(num_channels, num_classes, name=name, **kwargs)
def main(epochs,
         name,
         log_dir_now=None,
         batch_size=2,
         val_batch_size=2,
         lr=1e-4,
         lr_drop=0.9,
         lr_drop_freq=5,
         lr_warmup=3,
         num_to_visualise=2,
         num_channels=4,
         buffer_size=4,
         enable_function=True,
         tfrec_dir='./Data/tfrecords/',
         multi_class=False,
         crop_size=144,
         depth_crop_size=80,
         aug=None,
         debug=False,
         predict_slice=False,
         tpu=False,
         min_lr=1e-7,
         custom_loss=None,
         **model_kwargs,
         ):
    """Train a 3D VNet under a distribution strategy, then validate the
    best checkpoint.

    Builds datasets, metrics, loss, optimizer and model, runs the custom
    training loop, reloads the best weights and (for volume prediction)
    runs the final validation pass, appending the result to
    ``results/3d_result.txt``.
    """
    t0 = time()

    # Normalise here: `aug` defaults to None instead of a mutable `[]`
    # (shared-default pitfall); behaviour is unchanged for callers.
    aug = [] if aug is None else aug

    if tpu:
        tfrec_dir = 'gs://oai-challenge-dataset/tfrecords'

    num_classes = 7 if multi_class else 1

    # Per-metric slots: [loss_fn, train mean, val mean, None, None] —
    # consumed by the Metric bookkeeper.
    metrics = {
        'losses': {
            'mIoU': [iou_loss, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None],
            'dice': [dice_loss, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None]
        },
    }

    if multi_class:
        metrics['losses']['mIoU-6ch'] = [iou_loss_eval_3d, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None]
        metrics['losses']['dice-6ch'] = [dice_coef_eval_3d, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None]

    train_ds, valid_ds = load_datasets(batch_size, buffer_size, tfrec_dir, multi_class,
                                       crop_size=crop_size, depth_crop_size=depth_crop_size, aug=aug,
                                       predict_slice=predict_slice)

    if tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='pit-tpu')
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    else:
        strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
        if custom_loss is None:
            loss_func = tversky_loss if multi_class else dice_loss
        elif multi_class and custom_loss == "weighted":
            loss_func = dice_loss_weighted_3d
        elif multi_class and custom_loss == "focal":
            loss_func = focal_tversky
        else:
            raise NotImplementedError(f"Custom loss: {custom_loss} not implemented.")

        lr_manager = LearningRateUpdate(lr, lr_drop, lr_drop_freq, warmup=lr_warmup, min_lr=min_lr)

        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        model = build_model(num_channels, num_classes, name, predict_slice=predict_slice, **model_kwargs)

        trainer = Train(epochs, batch_size, enable_function,
                        model, optimizer, loss_func, lr_manager, predict_slice, metrics,
                        tfrec_dir=tfrec_dir)

        train_ds = strategy.experimental_distribute_dataset(train_ds)
        valid_ds = strategy.experimental_distribute_dataset(valid_ds)

        if log_dir_now is None:
            # BUG FIX: the original positional call passed `debug` into
            # `visual_save_freq` and `num_to_visualise` into `debug`
            # (signature is ..., visual_save_freq=5, debug=False,
            # num_to_visualise=0). Keywords route each argument correctly.
            log_dir_now = trainer.train_model_loop(train_ds, valid_ds, strategy, multi_class,
                                                   debug=debug,
                                                   num_to_visualise=num_to_visualise)

    train_time = time() - t0
    print(f"Train Time: {train_time:.02f}")
    t1 = time()
    with strategy.scope():
        # Rebuild a fresh model and load the best checkpoint for validation.
        model = build_model(num_channels, num_classes, name, predict_slice=predict_slice, **model_kwargs)
        model.load_weights(os.path.join(log_dir_now, 'best_weights.tf')).expect_partial()
    print("Validation for:", log_dir_now)

    if not predict_slice:
        total_loss, metric_str = validate_best_model(model,
                                                     log_dir_now,
                                                     val_batch_size,
                                                     buffer_size,
                                                     tfrec_dir,
                                                     multi_class,
                                                     crop_size,
                                                     depth_crop_size,
                                                     predict_slice,
                                                     Metric(metrics))
        print(f"Train Time: {train_time:.02f}")
        print(f"Validation Time: {time() - t1:.02f}")
        print(f"Total Time: {time() - t0:.02f}")
        with open("results/3d_result.txt", "a") as f:
            f.write(f'{log_dir_now}: total_loss {total_loss} {metric_str} \n')