Diff of /src/run/trainer.py [000000] .. [42b7b1]

Switch to unified view

a b/src/run/trainer.py
1
""" trainer python script
2
3
This script allows training the proposed lfbnet model.
4
This script requires to specify the directory path to the preprocessed PET MIP images. It could read the patient ids
5
from
6
the given directory path, or it could accept patient ids as .xls and .csv files. please provide the directory path to
7
the
8
csv or xls file. It assumes the csv/xls file has two columns labelled 'train' and 'valid' indicating the training and
9
validation patient ids respectively.
10
11
Please see the `if __name__ == '__main__':` section as an example, which is equivalent to:
12
13
e.g.train_valid_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/"
14
    train_valid_ids_path_csv = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/'
15
    train_ids, valid_ids = get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
16
17
    trainer = NetworkTrainer(
18
        folder_preprocessed_train=train_valid_data_dir, folder_preprocessed_valid=train_valid_data_dir,
19
        ids_to_read_train=train_ids,
20
        ids_to_read_valid=valid_ids
21
        )
22
    trainer.train()
23
"""
24
# Import libraries
25
import os
26
import glob
27
import sys
28
import time
29
from datetime import datetime
30
31
import numpy as np
32
from numpy.random import seed
33
from random import randint
34
from tqdm import tqdm
35
from typing import Tuple, List
36
from numpy import ndarray
37
from copy import deepcopy
38
from medpy.metric import binary
39
import matplotlib.pyplot as plt
40
from keras import backend as K
41
import re
42
43
# make LFBNet as parent directory, for absolute import libraries. local application import.
44
p = os.path.abspath('../..')
45
if p not in sys.path:
46
    sys.path.append(p)
47
48
# import LFBNet modules
49
from src.LFBNet.data_loader import DataLoader
50
from src.LFBNet.network_architecture import lfbnet
51
from src.LFBNet.losses import losses
52
from src.LFBNet.preprocessing import save_nii_images
53
from src.LFBNet.utilities import train_valid_paths
54
from src.LFBNet.postprocessing import remove_outliers_in_sagittal
55
# choose cuda gpu
56
CUDA_VISIBLE_DEVICES = 1
57
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
58
59
# seed NumPy's global random generator so that runs are repeatable across experiments.
60
seed(1)
61
62
# Define the parameters of the data to process
63
K.set_image_data_format('channels_last')
64
65
66
def default_training_parameters(
        num_epochs: int = 5000, batch_size: int = 16, early_stop: int = None, fold_number: int = None,
        model_name_save: List[str] = None, loss: str = None, metric: str = None
        ) -> dict:
    """ Configure default parameters for training.

    Training parameters are set here. For other options, the caller should override these values.

    Parameters
    ----------
    num_epochs: int, maximum number of epochs to train the model.
    batch_size: int, number of images per batch.
    early_stop: int, optional, number of epochs the model may train without improving before training stops.
        Defaults to 50 % of ``num_epochs``.
    fold_number: optional, fold identifier when applying cross-validation-based training. Defaults to the string
        'fold_run_at_<timestamp>'.
    model_name_save: list of two str, optional, names for the forward and the feedback model (in that order).
        Defaults to 'forward_<timestamp>' and 'feedback_<timestamp>'.
    loss: optional, loss function. Defaults to ``losses.LossMetric.dice_plus_binary_cross_entropy_loss``.
    metric: optional, metric function. Defaults to ``losses.LossMetric.dice_metric``.

    Returns
    -------
    dict with the keys 'num_epochs', 'batch_size', 'num_early_stop', 'fold_number', 'model_name_save_forward',
    'model_name_save_feedback', 'custom_loss' and 'custom_dice'.

    """
    # One timestamp for the whole configuration, so that the generated fold id and the forward / feedback model
    # names of the same run carry the same suffix and can be matched to each other.
    run_stamp = str(time.time())

    if early_stop is None:
        # early stop 50 % of the maximum number of epochs
        early_stop = int(num_epochs * 0.5)

    if fold_number is None:
        fold_number = 'fold_run_at_' + run_stamp

    if model_name_save is None:
        model_name_save = ["forward_" + run_stamp, "feedback_" + run_stamp]

    if loss is None:
        loss = losses.LossMetric.dice_plus_binary_cross_entropy_loss

    if metric is None:
        metric = losses.LossMetric.dice_metric

    config_trainer = {
        'num_epochs': num_epochs,
        'batch_size': batch_size,
        'num_early_stop': early_stop,
        'fold_number': fold_number,
        'model_name_save_forward': model_name_save[0],
        'model_name_save_feedback': model_name_save[1],
        "custom_loss": loss,
        "custom_dice": metric,
    }

    return config_trainer
108
109
110
def get_training_and_validation_ids_from_csv(path):
    """ Read the training and validation patient ids from a csv or xls file.

    The file is read by `train_valid_paths.read_csv_train_valid_index`; the training ids are expected under the
    column name 'train' and the validation ids under 'valid'.

    Parameters
    ----------
    path: directory path to the csv or xls file.

    Returns
    -------
    Tuple (training ids, validation ids): the first and the second element of the reader's result.

    """
    split_ids = train_valid_paths.read_csv_train_valid_index(path)
    return split_ids[0], split_ids[1]
127
128
129
def get_train_valid_ids_from_folder(path_train_valid, ratio_valid_data=0.25):
    """ Randomly split the patient ids found in one folder into training and validation ids.

    Every entry of the folder (as returned by `os.listdir`) is treated as one patient id. The ids are shuffled with
    `np.random.permutation`; the first ``int(ratio_valid_data * number_of_ids)`` shuffled ids become the validation
    ids and the remaining ones the training ids.

    Parameters
    ----------
    path_train_valid: path (str or os.PathLike) of the folder holding all cases, or a sequence containing exactly
        one such path.
    ratio_valid_data: fraction of the cases to use for validation, default 0.25.

    Returns
    -------
    [train, valid]: two numpy arrays with the training and the validation patient ids respectively.
    None is returned when a sequence with a number of paths other than one is given (only the single-folder split
    is implemented).

    """
    # A plain path is one folder. The previous `len(path) == 1` test was only true for one-character strings,
    # so ordinary paths silently produced None.
    if isinstance(path_train_valid, (str, os.PathLike)):
        folder = path_train_valid
    elif len(path_train_valid) == 1:
        folder = path_train_valid[0]
    else:
        return None

    case_ids = np.array(os.listdir(folder))  # all patients id

    # shuffle, then cut the shuffled order into the validation and the training part
    shuffled = np.random.permutation(len(case_ids))
    num_valid_data = int(ratio_valid_data * len(case_ids))

    # index back into the ids: callers need patient ids, not positions in the directory listing
    valid = case_ids[shuffled[:num_valid_data]]
    train = case_ids[shuffled[num_valid_data:]]
    return [train, valid]
154
155
156
class NetworkTrainer:
    """
    Class to train the lfb net.

    The best validation metric / loss and the early-stop counter are stored as class attributes, so they are shared
    by all instances of this class.
    """
    # keep the best loss and dice while training : Value shared across all instances, methods
    BEST_METRIC_VALIDATION = 0  # KEEP THE BEST VALIDATION METRIC SUCH AS THE DICE METRIC (BEST_DICE)
    BEST_LOSS_VALIDATION = 100  # KEEP THE BEST VALIDATION LOSS SUCH AS THE LOSS VALUES (BEST_LOSS)
    # COUNTS THE NUMBER OF CONSECUTIVE VALIDATION RUNS IN WHICH THE MODEL DID NOT IMPROVE, TO COMPARE WITH THE
    # EARLY STOP CRITERIA
    EARLY_STOP_COUNT = 0
    now = datetime.now()  # current time, date, month; evaluated once, when the class is defined
    # ctime string with spaces and colons replaced by "_": used in the file names of the saved weights
    TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime())
168
169
    def __init__(
            self, config_trainer: dict = None, folder_preprocessed_train: str = '../data/train/',
            folder_preprocessed_valid: str = '../data/valid/', ids_to_read_train: ndarray = None,
            ids_to_read_valid: ndarray = None, task: str = 'valid', predicted_directory: str = '../data/predicted/',
            save_predicted: bool = False
            ):
        """ Set up the data locations, the training configuration and the LFBNet model.

        :param config_trainer: training configuration dictionary with the keys produced by
            `default_training_parameters`; when None the defaults of that function are used.
        :param folder_preprocessed_train: directory with the preprocessed training cases.
        :param folder_preprocessed_valid: directory with the preprocessed validation cases.
        :param ids_to_read_train: training case ids; when None every entry of `folder_preprocessed_train` is used.
        :param ids_to_read_valid: validation case ids; when None every entry of `folder_preprocessed_valid` is used.
        :param task: 'train' to train the model, 'valid' to validate the latest saved weights (see `train`).
        :param predicted_directory: directory where predicted images are saved.
        :param save_predicted: when True every evaluated case is saved, otherwise only a random subset
            (see `evaluation`).
        """
        # Always store a configuration. Previously the attribute was only set when no configuration was given,
        # so passing one made every later `self.config_trainer` access fail with AttributeError.
        if config_trainer is None:
            config_trainer = default_training_parameters()
        self.config_trainer = deepcopy(config_trainer)

        # training data
        self.folder_preprocessed_train = folder_preprocessed_train
        if ids_to_read_train is None:
            ids_to_read_train = os.listdir(folder_preprocessed_train)

        self.ids_to_read_train = ids_to_read_train

        # validation data
        self.folder_preprocessed_valid = folder_preprocessed_valid
        if ids_to_read_valid is None:
            ids_to_read_valid = os.listdir(folder_preprocessed_valid)
        self.ids_to_read_valid = ids_to_read_valid

        # save predicted directory:
        self.save_all = save_predicted
        self.predicted_directory = predicted_directory
        # load the lfb_network architecture
        self.model = lfbnet.LfbNet()
        self.task = task

        # latent feedback at zero time: means no feedback from feedback network.
        # All-zero array with one entry per image of a training batch and the model's latent dimensions.
        self.latent_dim = self.model.latent_dim
        self.h_at_zero_time = np.zeros(
            (int(self.config_trainer['batch_size']), int(self.latent_dim[0]), int(self.latent_dim[1]),
             int(self.latent_dim[2])), np.float32
            )
218
219
    @staticmethod
    def load_dataset(directory_: str = None, ids_to_read: List[str] = None):
        """ Load the images and ground truths of the given cases and append a channel axis to both.

        :param directory_: directory handed to `DataLoader` as `data_dir`.
        :param ids_to_read: case ids handed to `DataLoader` as `ids_to_read`.
        :return: tuple (input images, ground truth images), each with a new trailing axis of length one.
        """
        loader = DataLoader(data_dir=directory_, ids_to_read=ids_to_read)
        loaded = loader.get_batch_of_data()

        # first element: input images, second element: ground truth
        images, masks = loaded[0], loaded[1]

        # the network expects channels last: add the channel axis at the end
        return np.expand_dims(images, axis=-1), np.expand_dims(masks, axis=-1)
236
237
    def load_latest_weight(self):
238
        """ loads the weights of the model with the latest saved weight in the folder ./weight
239
        """
240
        # load the last trained weight in the folder weight
241
        folder_path = r'./weight/'
242
        file_type = r'\*.h5'
243
        files = glob.glob(folder_path + file_type)
244
        try:
245
            max_file = max(files, key=os.path.getctime)
246
        except:
247
            raise Exception("weight could not found !")
248
249
        base_name = str(os.path.basename(max_file))
250
        print(base_name)
251
        self.model.combine_and_train.load_weights('./weight/forward_system' + str(base_name.split('system')[1]))
252
        # f
253
        self.model.fcn_feedback.load_weights('./weight/feedback_system' + str(base_name.split('system')[1]))
254
255
    def train(self):
256
        """Train the model
257
        """
258
259
        batch_size = self.config_trainer['batch_size']
260
        # self.load_latest_weight()
261
        # training
262
        if self.task == 'train':
263
            # training
264
            for current_epoch in range(self.config_trainer['num_epochs']):
265
                feedback_loss_dice = []
266
                forward_loss_dice = []
267
                forward_decoder_loss_dice = []
268
269
                # shuffle the index of the training data
270
                index_read = np.random.permutation(int(len(self.ids_to_read_train)))
271
                # read data
272
                for selected_patient in range(len(index_read)):
273
                    # get index of batch of data
274
                    start = selected_patient * batch_size
275
                    idx_list_batch = index_read[start:start + batch_size]
276
                    # if there are still elements in the given batch
277
                    if idx_list_batch.size > 0:
278
                        # get index of Why not ? kk = indx_list_batch
279
                        kk = [str(k) for i, k in enumerate(self.ids_to_read_train) if i in idx_list_batch]
280
281
                        batch_input_data, batch_output_data = self.load_dataset(
282
                            directory_=self.folder_preprocessed_train, ids_to_read=kk
283
                            )
284
285
                        assert len(batch_input_data) > 0, "batch of data not loaded correctly"
286
287
                        # shuffle within the batch
288
                        index_batch = np.random.permutation(int(batch_input_data.shape[0]))
289
                        batch_input_data = batch_input_data[index_batch]
290
                        batch_output_data = batch_output_data[index_batch]
291
292
                        # batches per epoch: Selected batch might as in id could have more images than the batch size
293
                        batch_per_epoch = int(batch_input_data.shape[0] / batch_size)
294
                        for batch_per_epoch_ in range(batch_per_epoch):
295
                            batch_input = batch_input_data[
296
                                          batch_per_epoch_ * batch_size:(batch_per_epoch_ + 1) * batch_size]
297
                            batch_output = batch_output_data[
298
                                           batch_per_epoch_ * batch_size:(batch_per_epoch_ + 1) * batch_size]
299
300
                            # Train forward models
301
                            if current_epoch % 2 == 0:
302
                                # step 1: train the forward network encoder and decoder
303
                                loss, dice = self.model.combine_and_train.train_on_batch(
304
                                    [batch_input, self.h_at_zero_time], [batch_output]
305
                                    )  # self.h_at_zero_time
306
                                forward_loss_dice.append([loss, dice])
307
308
                            else:
309
                                predicted_decoder = self.model.combine_and_train.predict(
310
                                    [batch_input, self.h_at_zero_time]
311
                                    )  # , self.h_at_zero_time
312
313
                                # step 2: train the feedback network, considering the output of the forward network
314
                                loss, dice = self.model.fcn_feedback.train_on_batch(predicted_decoder, batch_output)
315
                                feedback_loss_dice.append([loss, dice])
316
317
                                # Step 3: train the forward decoder, considering the trained
318
                                feedback_latent_result = self.model.feedback_latent.predict([predicted_decoder])
319
                                forward_encoder_output = self.model.forward_encoder.predict([batch_input])
320
321
                                # forward_encoder_output.insert(1, feedback_latent_result)
322
                                forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
323
                                forward_encoder_output.insert(1, feedback_latent_result)
324
                                loss, dice = self.model.forward_decoder.train_on_batch(
325
                                    [output for output in forward_encoder_output], [batch_output]
326
                                    )
327
                                forward_decoder_loss_dice.append([loss, dice])
328
329
                forward_loss_dice = np.array(forward_loss_dice)
330
                feedback_loss_dice = np.array(feedback_loss_dice)
331
                forward_decoder_loss_dice = np.array(forward_decoder_loss_dice)
332
333
                if current_epoch % 2 == 0:
334
                    loss, dice = np.mean(forward_loss_dice, axis=0)
335
                    print(
336
                        'Training_forward_system: >%d, '
337
                        ' fwd_loss = %.3f, fwd_dice=%0.3f, ' % (current_epoch, loss, dice)
338
                        )
339
340
                else:
341
                    loss_forward, dice_forward = np.mean(forward_decoder_loss_dice, axis=0)
342
                    loss_feedback, dice_feedback = np.mean(feedback_loss_dice, axis=0)
343
344
                    print(
345
                        'Training_forward_decoder_and_feedback_system: >%d, '
346
                        'fwd_decoder_loss=%03f, '
347
                        'fwd_decoder_dice=%0.3f '
348
349
                        'fdb_loss=%03f, '
350
                        'fdb_dice=%.3f ' % (current_epoch, loss_forward, dice_forward, loss_feedback, dice_feedback)
351
                        )
352
                # validation test:
353
                self.validation(current_epoch=current_epoch)
354
355
                # CHECK TRAINING STOPPING CRITERIA:  maximum number of epochs (epoch - 1), meet early stop
356
                if NetworkTrainer.EARLY_STOP_COUNT == self.config_trainer['num_early_stop']:
357
                    # save model with early stop identification
358
                    self.model.combine_and_train.save(
359
                        'weight/forward_system_early_stopped_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
360
                        )
361
                    self.model.fcn_feedback.save(
362
                        'weight/feedback_system_early_stopped_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
363
                        )
364
                    break  # STOP TRAINING WITH BREAK, OR EXIT TRAINING
365
366
        # if not training load the last saved weights, and check validation
367
        elif self.task == 'valid':
368
            # load the last trained weight in the folder weight
369
            self.load_latest_weight()
370
            self.validation(current_epoch=self.config_trainer['num_epochs'])
371
372
    def validation(self, verbose: int = 0, current_epoch: int = None):
373
        """
374
        Compute the validation dice, loss of the training from the validation data
375
        """
376
        # path to the validation data, if not specified, the default path ../data/valid/ would be considered
377
378
        folder_preprocessed = self.folder_preprocessed_valid
379
380
        # image folder names, or identifier: if not specified the default values would be the name of the folder inside
381
        # the directory "folder processed" or the self.folder_processed_valid :
382
383
        valid_identifier = self.ids_to_read_valid
384
385
        '''
386
        WE CAN IMPLEMENT THE EVALUATION METHOD AS BATCH BASED, PATIENT BASED, OR THE WHOLE-VALIDATION DATA BASED. FOR
387
        THE LAST OPTION WE NEED TO IMPLEMENT THE EVALUATION() FUNCTION HERE. 
388
        '''
389
390
        ''''
391
        declare variables to return: 
392
                        forward loss and dice with h0 (no feedback), 
393
                        feedback network loss and dice
394
                        forward decoder loss and dice,
395
                        forward loss and dice with ht (with feedback latent space)
396
        '''
397
        loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
398
                     'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}
399
400
        all_dice_sen_sep = {'dice': [], 'specificity': [], 'sensitivity': []}
401
402
        # load the dataset,
403
        # get the validation ids
404
        for id_to_validate in valid_identifier:
405
            try:
406
                id_to_validate = str(id_to_validate).split('.')[0]
407
            except:
408
                pass
409
410
            valid_input, valid_output = self.load_dataset(directory_=folder_preprocessed, ids_to_read=[id_to_validate])
411
412
            if len(valid_input) == 0:
413
                print("data %s not read" % id_to_validate)
414
                continue
415
416
            results, dice_sen_sep = self.evaluation(
417
                input_image=valid_input.copy(), ground_truth=valid_output.copy(), case_name=str(id_to_validate)
418
                )
419
420
            # append all loss to loss and dice to dice from all cases in valid identifiers
421
            for keys in results.keys():
422
                loss_dice[str(keys)].append(results[str(keys)][0])
423
424
            for keys in dice_sen_sep.keys():
425
                all_dice_sen_sep[str(keys)].append(dice_sen_sep[str(keys)][0])
426
427
        print("\n Dice, sensitivity, specificity \t")
428
        for k, v in all_dice_sen_sep.items():
429
            print('%s :  %0.3f ' % (k, np.mean(list(v), axis=0)), end=" ")
430
        print("\n")
431
432
        """
433
        print the mean of the validation loss and validation dice
434
        """
435
436
        # FOR STOPPING CRITERIA WE ARE USING THE MODEL AT THE 3RD STEP
437
        dice_mean = np.mean(loss_dice['dice_fwd_ht'])
438
        loss_mean = np.mean(loss_dice['loss_fwd_ht'])
439
440
        # at the first epoch
441
        if current_epoch == 0:
442
            NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean
443
            NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean
444
445
        # compare the current dice and loss with the previous epoch's loss and dice:
446
        # NOW CONSIDER DICE AS OPTIMIZATION METRIC
447
        print("Current validation loss and metrics at epoch %d: >> " % current_epoch, end=" ")
448
        for k, v in loss_dice.items():
449
            print('%s :  %0.3f ' % (k, np.mean(v)), end=" ")
450
        print("\n")
451
452
        if NetworkTrainer.BEST_METRIC_VALIDATION <= dice_mean:
453
            # reset early stop count, best dice, and best loss values
454
            NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean
455
            NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean
456
            NetworkTrainer.EARLY_STOP_COUNT = 0
457
458
            # save the best model weights
459
            if not os.path.exists('./weight'):
460
                os.mkdir('./weight')
461
462
            self.model.combine_and_train.save(
463
                'weight/forward_system_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
464
                )
465
            self.model.fcn_feedback.save(
466
                'weight/feedback_system_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
467
                )
468
        else:  # just print the current validation metric (dice) and loss, and count early stop
469
            # Increase the early stop count per epoch
470
            NetworkTrainer.EARLY_STOP_COUNT += 1
471
472
        print(
473
            '\n Best model on validation data : %0.3f :  Dice: %0.3f \n' % (
474
                NetworkTrainer.BEST_LOSS_VALIDATION, NetworkTrainer.BEST_METRIC_VALIDATION)
475
            )
476
477
    def evaluation(
            self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None,
            validation_or_test: str = 'test', case_name: str = None
            ):
        """ Evaluate every stage of the model on one case and optionally save the prediction.

        Stages, in order:
          0. forward system with an all-zero latent feedback (h0);
          2. feedback network on the prediction of stage 0;
          3. forward decoder fed with the encoder outputs plus the feedback latent;
          4. forward system with the feedback latent (ht).
        The final prediction (stage 4) and the ground truth are thresholded to compute dice, sensitivity and
        specificity with `medpy.metric.binary`.

        When `self.save_all` is set the prediction, the thresholded ground truth and the input are always saved
        with `save_nii_images`; otherwise they are saved for a random subset of the calls.

        :param verbose: verbosity passed to the models' `evaluate` calls.
        :param input_image: input images of the case; the first axis indexes the images.
        :param ground_truth: ground truth of the case; it is not modified.
        :param validation_or_test: currently unused.
        :param case_name: identifier used in the names of the saved images.
        :return: tuple (all_loss_dice, dice_sen_sp): a dict with one-element lists of the loss and dice of every
            stage, and a dict with one-element lists of 'dice', 'specificity' and 'sensitivity'.
        """
        all_loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
                         'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}

        dice_sen_sp = {'dice': [], 'specificity': [], 'sensitivity': []}

        # latent feedback variable h0: one all-zero latent per input image
        h0_input = np.zeros(
            (len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32
            )

        # step 0: loss and dice of the forward system without feedback
        loss, dice = self.model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose)
        all_loss_dice['loss_fwd_h0'].append(loss)
        all_loss_dice['dice__fwd_h0'].append(dice)

        # predict from the forward system
        predicted = self.model.combine_and_train.predict([input_image, h0_input])

        # step 2: loss and dice of the feedback system
        loss, dice = self.model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose)
        all_loss_dice['loss_fdb_h0'].append(loss)
        all_loss_dice['dice_fdb_h0'].append(dice)

        # step 3: forward decoder with the feedback latent inserted after the bottleneck
        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf
        forward_encoder_output = self.model.forward_encoder.predict([input_image])  # forward system's encoder output

        forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
        forward_encoder_output.insert(1, feedback_latent)
        loss, dice = self.model.forward_decoder.evaluate(
            list(forward_encoder_output), [ground_truth], verbose=verbose
            )
        all_loss_dice['loss_fwd_decoder'].append(loss)
        all_loss_dice['dice_fwd_decoder'].append(dice)

        # loss and dice from the combined and feed back latent space : input [input_image, fdb_latent_space]
        loss, dice = self.model.combine_and_train.evaluate(
            [input_image, feedback_latent], [ground_truth], verbose=verbose
            )
        all_loss_dice['loss_fwd_ht'].append(loss)
        all_loss_dice['dice_fwd_ht'].append(dice)

        # For the reported metrics, use the defined metrics on the predicted images instead of model.evaluate
        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])

        # binary.dc, sensitivity and specificity work only on binary images. Threshold copies, once, so that the
        # raw prediction stays available for saving (no second forward pass) and the caller's ground truth is
        # left untouched.
        binary_prediction = NetworkTrainer.threshold_image(predicted.copy())
        binary_truth = NetworkTrainer.threshold_image(ground_truth.copy())

        dice_sen_sp['dice'].append(binary.dc(binary_prediction, binary_truth))
        dice_sen_sp['sensitivity'].append(binary.sensitivity(binary_prediction, binary_truth))
        dice_sen_sp['specificity'].append(binary.specificity(binary_prediction, binary_truth))

        # Always save when requested, otherwise save a random subset of the cases: randint(0, 10) is a multiple
        # of three for 4 of its 11 possible values.
        if self.save_all or randint(0, 10) % 3 == 0:
            case_label = str(case_name)
            save_nii_images(
                [predicted, binary_truth, input_image], identifier=case_label,
                name=[case_label + "_predicted", case_label + "_ground_truth", case_label + "_image"],
                path_save=self.predicted_directory
                )

        return all_loss_dice, dice_sen_sp
577
578
    @staticmethod
    def display_image(im_display: ndarray):
        """ Display the given images in a two-column grid of subplots.

        :param im_display: sequence (or array whose first axis indexes the images) of 2D images to display.
        :returns: None; the figure is shown with `plt.show()`.
        """
        images = list(im_display)
        n_cols = 2
        # at least the former 3 x 2 layout; more rows when there are more than six images, for which the fixed
        # 3 x 2 grid raised an error
        n_rows = max(3, -(-len(images) // n_cols))

        plt.figure(figsize=(10, 8))
        plt.subplots_adjust(hspace=0.5)
        plt.suptitle("Images", fontsize=18, y=0.95)
        # one subplot per image, filled row by row
        for n, im in enumerate(images):
            plt.subplot(n_rows, n_cols, n + 1)
            plt.imshow(im)
        plt.show()
594
595
    @staticmethod
596
    # binary.dc, sen, and specificty works only on binary images
597
    def threshold_image(im_: ndarray, thr_value: float = 0.5) -> ndarray:
598
        """ threshold given input array with the given thresholding value
599
600
        :param im_: ndarray of images
601
        :param thr_value: thresholding value
602
        :return: threshold array image
603
        """
604
        im_[im_ > thr_value] = 1
605
        im_[im_ < thr_value] = 0
606
        return im_
607
608
609
class ModelTesting:
    """ Runs a trained LFBNet on a given data set.

    On construction the most recently created weight file in ``<current working directory>/src/weight`` is located,
    the matching forward and feedback weights are loaded, and :meth:`test` is executed immediately. For every case
    the predicted segmentation is saved; when a ground truth is available the loss and dice of the individual
    network parts are evaluated and printed as well.
    """
    now = datetime.now()  # time stamp taken once, when the class body is executed
    # ctime string with blanks and colons replaced by underscores
    TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime())
    print("current directory", os.getcwd())

    def __init__(
            self, config_test: dict = None, preprocessed_dir: str = '../data/test/', data_list: List[str] = None,
            predicted_dir: str = "../data/predicted"
            ):
        """ Load the network with its latest weights and run the test.

        :param config_test: configuration dictionary; if None a copy of ``default_training_parameters()`` is used
        :param preprocessed_dir: directory that contains the preprocessed cases to test
        :param data_list: identifiers of the cases to test; if None, every entry of ``preprocessed_dir`` is used
        :param predicted_dir: directory below which the predictions are saved (in the sub folder 'predicted_data')
        :raises FileNotFoundError: if no usable ``.h5`` weight file is found in the weight folder
        """
        # keep a caller supplied configuration; previously the attribute was only set for the default case
        if config_test is None:
            config_test = deepcopy(default_training_parameters())
        self.config_test = config_test

        # data to test and the place to store the results
        self.preprocessed_dir = preprocessed_dir
        self.predicted_dir = predicted_dir

        # if the list of testing cases are not given, get from the directory
        if data_list is None:
            data_list = os.listdir(preprocessed_dir)

        self.data_list = data_list

        # load the lfb_network architecture
        self.model = lfbnet.LfbNet()

        # shape of the latent feedback; used to build the all-zero feedback (no feedback) input
        self.latent_dim = self.model.latent_dim

        # load the last trained weight in the folder weight
        print(os.getcwd())
        folder_path = os.path.join(os.getcwd(), 'src/weight')
        print(folder_path)

        full_path = glob.glob(str(folder_path) + '/*.h5')

        print("files \n", full_path)
        try:
            # newest file according to os.path.getctime
            max_file = max(full_path, key=os.path.getctime)
        except (ValueError, OSError) as err:
            # ValueError: no .h5 file at all; OSError: a listed file could not be inspected
            raise FileNotFoundError("weight could not found !") from err

        base_name = str(os.path.basename(max_file))
        print(base_name)
        # the weight files are expected to be named 'forward_system<suffix>' and 'feedback_system<suffix>';
        # the suffix of the newest file selects the pair that belongs together
        weight_suffix = str(base_name.split('system')[1])
        self.model.combine_and_train.load_weights(os.path.join(folder_path, 'forward_system' + weight_suffix))
        self.model.fcn_feedback.load_weights(os.path.join(folder_path, 'feedback_system' + weight_suffix))

        self.test()

    def _zero_feedback_latent(self, batch_size: int) -> ndarray:
        """ Return the all-zero latent feedback h0 (no feedback) for ``batch_size`` input images. """
        latent_shape = tuple(int(self.latent_dim[axis]) for axis in range(3))
        return np.zeros((batch_size,) + latent_shape, np.float32)

    def test(self):
        """ Predict every case of ``self.data_list`` and print the loss and dice where a ground truth exists.

        Cases that could not be read are reported and skipped. Cases with a ground truth are evaluated with
        :meth:`evaluation_test`, all other cases are only predicted with :meth:`prediction`.
        """
        folder_preprocessed = self.preprocessed_dir
        # image folder names, or identifiers, of the cases to test
        test_identifier = self.data_list

        # values collected over all cases with a ground truth:
        #   forward loss and dice with h0 (no feedback),
        #   feedback network loss and dice,
        #   forward decoder loss and dice,
        #   forward loss and dice with ht (with feedback latent space)
        loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
                     'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}

        # becomes True as soon as one case was evaluated against a ground truth
        evaluated_any_case = False
        for id_to_test in tqdm(list(test_identifier)):
            test_input, test_output = NetworkTrainer.load_dataset(
                directory_=folder_preprocessed, ids_to_read=[id_to_test]
                )

            if len(test_input) == 0:
                print("data %s not read" % id_to_test)
                continue

            # there is a ground truth segmentation: compare it with the predicted segmentation
            if len(test_output):
                results = self.evaluation_test(
                    input_image=test_input.copy(), ground_truth=test_output.copy(), case_name=str(id_to_test)
                    )

                # append the loss and dice of this case to the values of the previous cases
                for key, values in results.items():
                    loss_dice[str(key)].append(values[0])
                evaluated_any_case = True

                # note: the printed numbers are the mean over all cases evaluated so far
                print("Results (sagittal and coronal) for case id: %s  : >> " % id_to_test, end=" ")
                for k, v in loss_dice.items():
                    print('%s :  %0.3f ' % (k, np.mean(v)), end=" ")
                print("\n")

            # no ground truth: predict the segmentation and save it
            else:
                self.prediction(input_image=test_input.copy(), case_name=str(id_to_test))

        # print the mean loss and dice over all evaluated cases; this must not depend on whether the
        # last case of the list happened to have a ground truth
        if evaluated_any_case:
            print("Total dataset metrics:  : >> ", end=" ")
            for k, v in loss_dice.items():
                print('%s :  %0.3f ' % (k, np.mean(v)), end=" ")
            print("\n")

    def evaluation_test(
            self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None,
            validation_or_test: str = 'validate', case_name: str = None
            ):
        """ Evaluate one case against its ground truth and save the final prediction.

        :param verbose: verbosity forwarded to the ``evaluate`` calls of the models
        :param input_image: input images of the case
        :param ground_truth: reference segmentation of the case
        :param validation_or_test: with "test" the dice, specificity and sensitivity of the prediction without
            feedback are returned and nothing is saved; any other value returns the loss/dice dictionary.
            NOTE(review): ``test()`` merges the result by the loss/dice keys, so it only supports the default.
        :param case_name: identifier of the case, used for the names of the saved images
        :return: dictionary of one-element lists with the loss and dice of the forward system with h0, the
            feedback system, the forward decoder and the forward system with feedback; or the metric dictionary
            described for ``validation_or_test``
        """
        all_loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
                         'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}
        # latent feedback variable h0: one all-zero latent per input image
        h0_input = self._zero_feedback_latent(len(input_image))

        # step 0:
        # Loss and dice of the forward system without feedback
        loss, dice = self.model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose)
        all_loss_dice['loss_fwd_h0'].append(loss)
        all_loss_dice['dice__fwd_h0'].append(dice)

        # predict from the forward system
        predicted = self.model.combine_and_train.predict([input_image, h0_input])

        # step 2:
        # Loss and dice of the feedback system
        loss, dice = self.model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose)
        all_loss_dice['loss_fdb_h0'].append(loss)
        all_loss_dice['dice_fdb_h0'].append(dice)

        # step 3:
        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf
        forward_encoder_output = self.model.forward_encoder.predict([input_image])  # forward system's encoder output

        forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
        # the decoder takes the feedback latent directly after the bottleneck
        forward_encoder_output.insert(1, feedback_latent)
        loss, dice = self.model.forward_decoder.evaluate(
            list(forward_encoder_output), [ground_truth], verbose=verbose
            )
        all_loss_dice['loss_fwd_decoder'].append(loss)
        all_loss_dice['dice_fwd_decoder'].append(dice)

        # loss and dice from the combined and feed back latent space : input  [input_image, fdb_latent_space]
        loss, dice = self.model.combine_and_train.evaluate(
            [input_image, feedback_latent], [ground_truth], verbose=verbose
            )
        all_loss_dice['loss_fwd_ht'].append(loss)
        all_loss_dice['dice_fwd_ht'].append(dice)

        # For the testing time, the metrics are computed on the predicted images instead of using
        # model.evaluate as for the validation cases
        if validation_or_test == "test":
            # return dice, specificity, and sensitivity
            return {'dice': binary.dc(predicted, ground_truth),
                    'specificity': binary.specificity(predicted, ground_truth),
                    'sensitivity': binary.sensitivity(predicted, ground_truth)}

        # final prediction with the feedback latent, cleaned and saved together with its references
        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
        predicted = remove_outliers_in_sagittal(predicted)
        save_nii_images(
            [predicted, ground_truth, input_image], identifier=str(case_name),
            name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_pet"],
            path_save=os.path.join(str(self.predicted_dir), 'predicted_data')
            )

        return all_loss_dice

    def prediction(self, input_image: ndarray = None, case_name: str = None):
        """ Predict the segmentation of one case without ground truth and save it.

        :param input_image: input images of the case
        :param case_name: identifier of the case, used for the names of the saved images
        """
        # latent feedback variable h0: one all-zero latent per input image
        h0_input = self._zero_feedback_latent(len(input_image))

        # STEP 1: forward system prediction without feedback
        predicted = self.model.combine_and_train.predict([input_image, h0_input])

        # step 2: Feedback system prediction
        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf

        # step 3: forward system prediction with the feedback latent
        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
        predicted = remove_outliers_in_sagittal(predicted)
        save_nii_images(
            image=[predicted, input_image], identifier=str(case_name),
            name=[case_name + "_predicted", case_name + "_pet"],
            path_save=os.path.join(str(self.predicted_dir), 'predicted_data')
            )
842
843
if __name__ == '__main__':
    # Example run: train and validate on cases stored in one directory, with the split into
    # training and validation ids read from csv/xls files.
    mip_directory = r"E:\LFBNet\data\remarc_default_MIP_dir/"
    ids_csv_directory = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/'
    ids_for_training, ids_for_validation = get_training_and_validation_ids_from_csv(ids_csv_directory)

    # training and validation cases live in the same directory; the id lists select them
    trainer = NetworkTrainer(
        folder_preprocessed_train=mip_directory,
        folder_preprocessed_valid=mip_directory,
        ids_to_read_train=ids_for_training,
        ids_to_read_valid=ids_for_validation,
        )
    trainer.train()