--- a
+++ b/src/run/trainer.py
@@ -0,0 +1,852 @@
+""" trainer python script
+
+This script allows training the proposed lfbnet model.
+This script requires the directory path to the preprocessed PET MIP images. It can read the patient ids
+from the given directory path, or it can accept patient ids as .xls or .csv files; in that case, please
+provide the directory path to the csv or xls file. It assumes the csv/xls file has two columns labelled
+'train' and 'valid', indicating the training and validation patient ids respectively.
+
+Please see the __name__ == '__main__' section as an example, which is equivalent to:
+
+e.g.train_valid_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/"
+    train_valid_ids_path_csv = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/'
+    train_ids, valid_ids = get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
+
+    trainer = NetworkTrainer(
+        folder_preprocessed_train=train_valid_data_dir, folder_preprocessed_valid=train_valid_data_dir,
+        ids_to_read_train=train_ids,
+        ids_to_read_valid=valid_ids
+        )
+    trainer.train()
+"""
+# Import libraries
+import os
+import glob
+import sys
+import time
+from datetime import datetime
+
+import numpy as np
+from numpy.random import seed
+from random import randint
+from tqdm import tqdm
+from typing import Tuple, List
+from numpy import ndarray
+from copy import deepcopy
+from medpy.metric import binary
+import matplotlib.pyplot as plt
+from keras import backend as K
+import re
+
+# make LFBNet as parent directory, for absolute import libraries. local application import.
+p = os.path.abspath('../..')
+if p not in sys.path:
+    sys.path.append(p)
+
+# import LFBNet modules
+from src.LFBNet.data_loader import DataLoader
+from src.LFBNet.network_architecture import lfbnet
+from src.LFBNet.losses import losses
+from src.LFBNet.preprocessing import save_nii_images
+from src.LFBNet.utilities import train_valid_paths
+from src.LFBNet.postprocessing import remove_outliers_in_sagittal
+# choose cuda gpu
+CUDA_VISIBLE_DEVICES = 1
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+# set randomness repeatable across experiments.
+seed(1)
+
+# Define the parameters of the data to process
+K.set_image_data_format('channels_last')
+
+
+def default_training_parameters(
+        num_epochs: int = 5000, batch_size: int = 16, early_stop: int = None, fold_number: int = None,
+        model_name_save: List[str] = None, loss: str = None, metric: str = None
+        ) -> dict:
+    """ Configure default parameters for training.
+    Training parameters are setted here. For other options, the user should modifier these values.
+    Parameters
+    ----------
+    num_epochs: int, maximum number of epochs to train the model.
+    batch_size: int, number of images per batch
+    early_stop: int, the number of training epochs the model should train while it is not improving the accuracy.
+    fold_number: int, optional, fold number while applying cross-validation-based training.
+    model_name_save: str, model name to save
+    loss: str, loss funciton
+    metric: str, specify the metric, such as dice
+
+    Returns
+    -------
+    Returns configured dictionary for the training.
+
+    """
+    if early_stop is None:
+        # early stop 50 % of the maximum number of epochs
+        early_stop = int(num_epochs * 0.5)
+
+    if fold_number is None:
+        fold_number = 'fold_run_at_' + str(time.time())
+
+    if model_name_save is None:
+        model_name_save = ["forward_" + str(time.time()), "feedback_" + str(time.time())]
+
+    if loss is None:
+        loss = losses.LossMetric.dice_plus_binary_cross_entropy_loss
+
+    if metric is None:
+        metric = losses.LossMetric.dice_metric
+
+    config_trainer = {'num_epochs': num_epochs, 'batch_size': batch_size, 'num_early_stop': early_stop,
+                      'fold_number': fold_number, 'model_name_save_forward': model_name_save[0],
+                      'model_name_save_feedback': model_name_save[1], "custom_loss": loss, "custom_dice": metric}
+
+    return config_trainer
+
+
+def get_training_and_validation_ids_from_csv(path):
+    """ Get training and validation ids from a given csv or xls file. Assuming the training ids are given with column
+    name 'train' and validation ids in 'valid'
+
+    Parameters
+    ----------
+    path: directory path to the csv or xls file.
+
+    Returns
+    -------
+    Returns training and validation patient ids.
+
+
+    """
+    ids = train_valid_paths.read_csv_train_valid_index(path)
+    train, valid = ids[0], ids[1]
+    return train, valid
+
+
+def get_train_valid_ids_from_folder(path_train_valid, ratio_valid_data=0.25):
+    """ Returns the randomly split training and validation patient ids. The percentage of validation is given by the
+    ratio_valid_data.
+
+    Parameters
+    ----------
+    path_train_valid: path of the directory that holds one entry per patient.
+    ratio_valid_data: fraction of the cases assigned to validation (default 0.25).
+
+     Returns
+    -------
+    Returns training patient id and validation patient ids respectively as in two array.
+
+    """
+    # given training and validation data in one folder, random splitting with 25% : train, valid
+    # NOTE(review): `len(path_train_valid) == 1` is only true for a one-character string or a
+    # one-element list/tuple — presumably a single-element list of paths was intended; when the
+    # test fails the function implicitly returns None. TODO confirm against callers.
+    if len(path_train_valid) == 1:
+        all_cases_id = os.listdir(str(path_train_valid))  # all patients id
+
+        # make permutation in the given list
+        case_ids = np.array(all_cases_id)
+        indices = np.random.permutation(len(case_ids))
+        num_valid_data = int(ratio_valid_data * len(all_cases_id))
+
+        # NOTE(review): `train` and `valid` are permuted integer *indices* into case_ids, not the
+        # id strings themselves — confirm that callers index case_ids with them.
+        train, valid = indices[num_valid_data:], indices[:num_valid_data]
+        return [train, valid]
+
+
+class NetworkTrainer:
+    """
+    class to train the lfb net
+    """
+    # keep the best loss and dice while training : Value shared across all instances, methods
+    BEST_METRIC_VALIDATION = 0  # KEEP THE BEST VALIDATION METRIC SUCH AS THE DICE METRIC (BEST_DICE)
+    BEST_LOSS_VALIDATION = 100  # KEEP THE BEST VALIDATION LOSS SUCH AS THE LOSS VALUES (BEST_LOSS)
+    EARLY_STOP_COUNT = 0  # COUNTS THE NUMBER OF TRAINING ITERATIONS THE MODEL DID NOT INCREASE, TO COMPARE WITH THE
+    now = datetime.now()  # current time, date, month,
+    TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime())
+
+    # EARLY STOP CRITERIA
+
+    def __init__(
+            self, config_trainer: dict = None, folder_preprocessed_train: str = '../data/train/',
+            folder_preprocessed_valid: str = '../data/valid/', ids_to_read_train: ndarray = None,
+            ids_to_read_valid: ndarray = None, task: str = 'valid', predicted_directory: str = '../data/predicted/',
+            save_predicted: bool = False
+            ):
+        """
+
+        :param config_trainer:
+        :param folder_preprocessed_train:
+        :param folder_preprocessed_valid:
+        :param ids_to_read_train:
+        :param ids_to_read_valid:
+        :param task:
+        :predicted_directory:
+        :save_predicted
+        """
+
+        if config_trainer is None:
+            self.config_trainer = deepcopy(default_training_parameters())
+
+        # training data
+        self.folder_preprocessed_train = folder_preprocessed_train
+        if ids_to_read_train is None:
+            ids_to_read_train = os.listdir(folder_preprocessed_train)
+
+        self.ids_to_read_train = ids_to_read_train
+
+        # validation data
+        self.folder_preprocessed_valid = folder_preprocessed_valid
+        if ids_to_read_valid is None:
+            ids_to_read_valid = os.listdir(folder_preprocessed_valid)
+        self.ids_to_read_valid = ids_to_read_valid
+
+        # save predicted directory:
+        self.save_all = save_predicted
+        self.predicted_directory = predicted_directory
+        # load the lfb_network architecture
+        self.model = lfbnet.LfbNet()
+        self.task = task
+
+        # forward network decoder
+
+        # latent feedback at zero time: means no feedback from feedback network
+        self.latent_dim = self.model.latent_dim
+        self.h_at_zero_time = np.zeros(
+            (int(self.config_trainer['batch_size']), int(self.latent_dim[0]), int(self.latent_dim[1]),
+             int(self.latent_dim[2])), np.float32
+            )
+
+    @staticmethod
+    def load_dataset(directory_: str = None, ids_to_read: List[str] = None):
+        """
+
+        :param ids_to_read:
+        :param directory_:
+        """
+        # load batch of data
+        data_loader = DataLoader(data_dir=directory_, ids_to_read=ids_to_read)
+        image_batch_ground_truth_batch = data_loader.get_batch_of_data()
+
+        batch_input_data, batch_output_data = image_batch_ground_truth_batch[0], image_batch_ground_truth_batch[1]
+        # expand dimension for the channel
+        batch_output_data = np.expand_dims(batch_output_data, axis=-1)
+        batch_input_data = np.expand_dims(batch_input_data, axis=-1)
+
+        return batch_input_data, batch_output_data
+
+    def load_latest_weight(self):
+        """ loads the weights of the model with the latest saved weight in the folder ./weight
+        """
+        # load the last trained weight in the folder weight
+        folder_path = r'./weight/'
+        file_type = r'\*.h5'
+        files = glob.glob(folder_path + file_type)
+        try:
+            max_file = max(files, key=os.path.getctime)
+        except:
+            raise Exception("weight could not found !")
+
+        base_name = str(os.path.basename(max_file))
+        print(base_name)
+        self.model.combine_and_train.load_weights('./weight/forward_system' + str(base_name.split('system')[1]))
+        # f
+        self.model.fcn_feedback.load_weights('./weight/feedback_system' + str(base_name.split('system')[1]))
+
+    def train(self):
+        """ Train the forward and feedback systems in alternating epochs.
+
+        Even epochs train the full forward system (encoder + decoder) without feedback;
+        odd epochs first train the feedback network on the forward system's prediction,
+        then train the forward decoder with the feedback latent space inserted. After
+        every epoch, validation() is called, which also saves the best weights and
+        tracks the early-stop counter. When self.task == 'valid', the latest saved
+        weights are loaded and only the validation step runs.
+        """
+
+        batch_size = self.config_trainer['batch_size']
+        # self.load_latest_weight()
+        # training
+        if self.task == 'train':
+            # training
+            for current_epoch in range(self.config_trainer['num_epochs']):
+                feedback_loss_dice = []
+                forward_loss_dice = []
+                forward_decoder_loss_dice = []
+
+                # shuffle the index of the training data
+                index_read = np.random.permutation(int(len(self.ids_to_read_train)))
+                # read data
+                # NOTE(review): this loops len(index_read) times but consumes batch_size ids per
+                # step, so most iterations produce an empty slice and are skipped by the size
+                # check below — presumably range(ceil(len/batch_size)) was intended; harmless,
+                # but wasteful. TODO confirm.
+                for selected_patient in range(len(index_read)):
+                    # get index of batch of data
+                    start = selected_patient * batch_size
+                    idx_list_batch = index_read[start:start + batch_size]
+                    # if there are still elements in the given batch
+                    if idx_list_batch.size > 0:
+                        # the patient ids whose position falls in the selected batch of indices
+                        kk = [str(k) for i, k in enumerate(self.ids_to_read_train) if i in idx_list_batch]
+
+                        batch_input_data, batch_output_data = self.load_dataset(
+                            directory_=self.folder_preprocessed_train, ids_to_read=kk
+                            )
+
+                        assert len(batch_input_data) > 0, "batch of data not loaded correctly"
+
+                        # shuffle within the batch
+                        index_batch = np.random.permutation(int(batch_input_data.shape[0]))
+                        batch_input_data = batch_input_data[index_batch]
+                        batch_output_data = batch_output_data[index_batch]
+
+                        # batches per epoch: a selected id could contribute more images than the batch size
+                        batch_per_epoch = int(batch_input_data.shape[0] / batch_size)
+                        for batch_per_epoch_ in range(batch_per_epoch):
+                            batch_input = batch_input_data[
+                                          batch_per_epoch_ * batch_size:(batch_per_epoch_ + 1) * batch_size]
+                            batch_output = batch_output_data[
+                                           batch_per_epoch_ * batch_size:(batch_per_epoch_ + 1) * batch_size]
+
+                            # Train forward models
+                            if current_epoch % 2 == 0:
+                                # step 1: train the forward network encoder and decoder
+                                loss, dice = self.model.combine_and_train.train_on_batch(
+                                    [batch_input, self.h_at_zero_time], [batch_output]
+                                    )  # self.h_at_zero_time
+                                forward_loss_dice.append([loss, dice])
+
+                            else:
+                                predicted_decoder = self.model.combine_and_train.predict(
+                                    [batch_input, self.h_at_zero_time]
+                                    )  # , self.h_at_zero_time
+
+                                # step 2: train the feedback network, considering the output of the forward network
+                                loss, dice = self.model.fcn_feedback.train_on_batch(predicted_decoder, batch_output)
+                                feedback_loss_dice.append([loss, dice])
+
+                                # Step 3: train the forward decoder, considering the trained
+                                feedback_latent_result = self.model.feedback_latent.predict([predicted_decoder])
+                                forward_encoder_output = self.model.forward_encoder.predict([batch_input])
+
+                                # insert the feedback latent after the bottleneck output
+                                forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
+                                forward_encoder_output.insert(1, feedback_latent_result)
+                                loss, dice = self.model.forward_decoder.train_on_batch(
+                                    [output for output in forward_encoder_output], [batch_output]
+                                    )
+                                forward_decoder_loss_dice.append([loss, dice])
+
+                forward_loss_dice = np.array(forward_loss_dice)
+                feedback_loss_dice = np.array(feedback_loss_dice)
+                forward_decoder_loss_dice = np.array(forward_decoder_loss_dice)
+
+                # report the mean loss/dice of whichever system(s) were trained this epoch
+                if current_epoch % 2 == 0:
+                    loss, dice = np.mean(forward_loss_dice, axis=0)
+                    print(
+                        'Training_forward_system: >%d, '
+                        ' fwd_loss = %.3f, fwd_dice=%0.3f, ' % (current_epoch, loss, dice)
+                        )
+
+                else:
+                    loss_forward, dice_forward = np.mean(forward_decoder_loss_dice, axis=0)
+                    loss_feedback, dice_feedback = np.mean(feedback_loss_dice, axis=0)
+
+                    print(
+                        'Training_forward_decoder_and_feedback_system: >%d, '
+                        'fwd_decoder_loss=%03f, '
+                        'fwd_decoder_dice=%0.3f '
+
+                        'fdb_loss=%03f, '
+                        'fdb_dice=%.3f ' % (current_epoch, loss_forward, dice_forward, loss_feedback, dice_feedback)
+                        )
+                # validation test: also saves the best weights and updates EARLY_STOP_COUNT
+                self.validation(current_epoch=current_epoch)
+
+                # CHECK TRAINING STOPPING CRITERIA:  maximum number of epochs (epoch - 1), meet early stop
+                if NetworkTrainer.EARLY_STOP_COUNT == self.config_trainer['num_early_stop']:
+                    # save model with early stop identification
+                    self.model.combine_and_train.save(
+                        'weight/forward_system_early_stopped_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
+                        )
+                    self.model.fcn_feedback.save(
+                        'weight/feedback_system_early_stopped_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
+                        )
+                    break  # STOP TRAINING WITH BREAK, OR EXIT TRAINING
+
+        # if not training load the last saved weights, and check validation
+        elif self.task == 'valid':
+            # load the last trained weight in the folder weight
+            self.load_latest_weight()
+            self.validation(current_epoch=self.config_trainer['num_epochs'])
+
+    def validation(self, verbose: int = 0, current_epoch: int = None):
+        """ Compute the validation loss and dice; save the best model and track early stop.
+
+        Every validation case is evaluated through evaluation(); the mean dice of the
+        forward system fed with the feedback latent space ('dice_fwd_ht') is the
+        optimization metric: when it improves (or ties), the weights are saved and the
+        class-level early-stop counter resets, otherwise the counter is incremented.
+
+        :param verbose: verbosity flag (currently unused — evaluation() uses its own default).
+        :param current_epoch: current training epoch; epoch 0 (re)initializes the best metrics.
+        """
+        # path to the validation data, if not specified, the default path ../data/valid/ would be considered
+
+        folder_preprocessed = self.folder_preprocessed_valid
+
+        # image folder names, or identifier: if not specified the default values would be the name of the folder inside
+        # the directory "folder processed" or the self.folder_processed_valid :
+
+        valid_identifier = self.ids_to_read_valid
+
+        '''
+        WE CAN IMPLEMENT THE EVALUATION METHOD AS BATCH BASED, PATIENT BASED, OR THE WHOLE-VALIDATION DATA BASED. FOR
+        THE LAST OPTION WE NEED TO IMPLEMENT THE EVALUATION() FUNCTION HERE. 
+        '''
+
+        ''''
+        declare variables to return: 
+                        forward loss and dice with h0 (no feedback), 
+                        feedback network loss and dice
+                        forward decoder loss and dice,
+                        forward loss and dice with ht (with feedback latent space)
+        '''
+        loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
+                     'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}
+
+        all_dice_sen_sep = {'dice': [], 'specificity': [], 'sensitivity': []}
+
+        # load the dataset,
+        # get the validation ids
+        for id_to_validate in valid_identifier:
+            # strip a file extension from the id when one is present
+            try:
+                id_to_validate = str(id_to_validate).split('.')[0]
+            except:
+                pass
+
+            valid_input, valid_output = self.load_dataset(directory_=folder_preprocessed, ids_to_read=[id_to_validate])
+
+            # skip (with a message) cases whose data could not be loaded
+            if len(valid_input) == 0:
+                print("data %s not read" % id_to_validate)
+                continue
+
+            results, dice_sen_sep = self.evaluation(
+                input_image=valid_input.copy(), ground_truth=valid_output.copy(), case_name=str(id_to_validate)
+                )
+
+            # append all loss to loss and dice to dice from all cases in valid identifiers
+            for keys in results.keys():
+                loss_dice[str(keys)].append(results[str(keys)][0])
+
+            for keys in dice_sen_sep.keys():
+                all_dice_sen_sep[str(keys)].append(dice_sen_sep[str(keys)][0])
+
+        print("\n Dice, sensitivity, specificity \t")
+        for k, v in all_dice_sen_sep.items():
+            print('%s :  %0.3f ' % (k, np.mean(list(v), axis=0)), end=" ")
+        print("\n")
+
+        """
+        print the mean of the validation loss and validation dice
+        """
+
+        # FOR STOPPING CRITERIA WE ARE USING THE MODEL AT THE 3RD STEP
+        dice_mean = np.mean(loss_dice['dice_fwd_ht'])
+        loss_mean = np.mean(loss_dice['loss_fwd_ht'])
+
+        # at the first epoch, seed the class-level best metrics with the current values
+        if current_epoch == 0:
+            NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean
+            NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean
+
+        # compare the current dice and loss with the previous epoch's loss and dice:
+        # NOW CONSIDER DICE AS OPTIMIZATION METRIC
+        print("Current validation loss and metrics at epoch %d: >> " % current_epoch, end=" ")
+        for k, v in loss_dice.items():
+            print('%s :  %0.3f ' % (k, np.mean(v)), end=" ")
+        print("\n")
+
+        # '<=' means a tie also counts as an improvement: weights are re-saved and the counter resets
+        if NetworkTrainer.BEST_METRIC_VALIDATION <= dice_mean:
+            # reset early stop count, best dice, and best loss values
+            NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean
+            NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean
+            NetworkTrainer.EARLY_STOP_COUNT = 0
+
+            # save the best model weights
+            if not os.path.exists('./weight'):
+                os.mkdir('./weight')
+
+            self.model.combine_and_train.save(
+                'weight/forward_system_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
+                )
+            self.model.fcn_feedback.save(
+                'weight/feedback_system_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
+                )
+        else:  # just print the current validation metric (dice) and loss, and count early stop
+            # Increase the early stop count per epoch
+            NetworkTrainer.EARLY_STOP_COUNT += 1
+
+        # NOTE(review): the first %0.3f below is the best LOSS, the second the best dice —
+        # the label only mentions Dice, which reads ambiguously.
+        print(
+            '\n Best model on validation data : %0.3f :  Dice: %0.3f \n' % (
+                NetworkTrainer.BEST_LOSS_VALIDATION, NetworkTrainer.BEST_METRIC_VALIDATION)
+            )
+
+    def evaluation(
+            self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None,
+            validation_or_test: str = 'test', case_name: str = None
+            ):
+        """ Evaluate one case through all LFBNet sub-systems and compute binary metrics.
+
+        :param verbose: keras evaluate verbosity.
+        :param input_image: preprocessed input images with trailing channel axis (as
+            produced by load_dataset).
+        :param ground_truth: reference segmentation of the same shape.
+        :param validation_or_test: not used in this body; kept for interface compatibility.
+        :param case_name: case identifier, used when saving predictions.
+        :return: tuple (all_loss_dice, dice_sen_sp) — per-sub-system loss/dice lists and
+            the binary dice/sensitivity/specificity of the final prediction.
+        """
+        ''''
+        declare variables to return: 
+                        forward loss and dice with h0 (no feedback), 
+                        feedback network loss and dice
+                        forward decoder loss and dice,
+                        forward loss and dice with ht (with feedback latent space)
+        '''
+        all_loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
+                         'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}
+
+        dice_sen_sp = {'dice': [], 'specificity': [], 'sensitivity': []}
+
+        # latent  feedback variable h0
+        # replace the first number of batches with the number of input images from the first channel
+        h0_input = np.zeros(
+            (len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32
+            )
+
+        # step 0:
+        # Loss and dice on the validation of the forward system
+        loss, dice = self.model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose)
+        all_loss_dice['loss_fwd_h0'].append(loss), all_loss_dice['dice__fwd_h0'].append(dice)
+
+        # predict from the forward system
+        predicted = self.model.combine_and_train.predict([input_image, h0_input])
+
+        # step 2:
+        # Loss and dice on the validation of the feedback system
+        loss, dice = self.model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose)
+        all_loss_dice['loss_fdb_h0'].append(loss), all_loss_dice['dice_fdb_h0'].append(dice)
+
+        # step 3:
+        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf
+        forward_encoder_output = self.model.forward_encoder.predict([input_image])  # forward system's encoder output
+
+        # insert the feedback latent after the bottleneck output, mirroring the training path
+        forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
+        forward_encoder_output.insert(1, feedback_latent)
+        loss, dice = self.model.forward_decoder.evaluate(
+            [output for output in forward_encoder_output], [ground_truth], verbose=verbose
+            )
+        all_loss_dice['loss_fwd_decoder'].append(loss), all_loss_dice['dice_fwd_decoder'].append(dice)
+
+        # loss and dice from the combined and feed back latent space : input  [input_image, fdb_latent_space]
+        loss, dice = self.model.combine_and_train.evaluate(
+            [input_image, feedback_latent], [ground_truth], verbose=verbose
+            )
+        all_loss_dice['loss_fwd_ht'].append(loss), all_loss_dice['dice_fwd_ht'].append(dice)
+        """
+        For the testing time, we use defined metrics on the predicted images instead of using model.evaluate during 
+        the validation cases 
+        """
+        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
+
+        # binary.dc, sen, and specificty works only on binary images
+        # NOTE(review): threshold_image binarizes its argument in place, so `predicted` and
+        # `ground_truth` are already binary after the first call — the later calls are idempotent.
+        dice_sen_sp['dice'].append(
+            binary.dc(NetworkTrainer.threshold_image(predicted), NetworkTrainer.threshold_image(ground_truth))
+            )
+        dice_sen_sp['sensitivity'].append(
+            binary.sensitivity(NetworkTrainer.threshold_image(predicted), NetworkTrainer.threshold_image(ground_truth))
+            )
+        dice_sen_sp['specificity'].append(
+            binary.specificity(NetworkTrainer.threshold_image(predicted), NetworkTrainer.threshold_image(ground_truth))
+            )
+        # all = np.concatenate((ground_truth, predicted, input_image), axis=0)
+        # display_image(all)
+
+        # Sometimes save predictions
+        # NOTE(review): the model is re-run here even though `predicted` was already computed
+        # above (and has since been binarized in place); presumably the fresh, non-thresholded
+        # prediction is what should be saved — confirm before simplifying.
+        if self.save_all:
+            predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
+            save_nii_images(
+                [predicted, ground_truth, input_image], identifier=str(case_name),
+                name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_image"], path_save=self.predicted_directory
+                )
+        else:
+            # when not saving everything, save a random ~1/3 of the cases
+            n = randint(0, 10)
+            if n % 3 == 0:
+                predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
+                save_nii_images(
+                    [predicted, ground_truth, input_image], identifier=str(case_name),
+                    name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_image"], path_save=self.predicted_directory
+                    )
+
+        return all_loss_dice, dice_sen_sp
+
+    @staticmethod
+    def display_image(im_display: ndarray):
+        """ display given images
+
+        :param all: 2D image arrays to display
+        :returns: display images
+        """
+        plt.figure(figsize=(10, 8))
+        plt.subplots_adjust(hspace=0.5)
+        plt.suptitle("Daily closing prices", fontsize=18, y=0.95)
+        # loop through the length of tickers and keep track of index
+        for n, im in enumerate(im_display):
+            # add a new subplot iteratively
+            plt.subplot(3, 2, n + 1)
+            plt.imshow(im)  # chart formatting
+        plt.show()
+
+    @staticmethod
+    # binary.dc, sen, and specificty works only on binary images
+    def threshold_image(im_: ndarray, thr_value: float = 0.5) -> ndarray:
+        """ threshold given input array with the given thresholding value
+
+        :param im_: ndarray of images
+        :param thr_value: thresholding value
+        :return: threshold array image
+        """
+        im_[im_ > thr_value] = 1
+        im_[im_ < thr_value] = 0
+        return im_
+
+
+class ModelTesting:
+    """ performs prediction on a given data set. It predicts the segmentation results, and save the results, calculate
+    the clinical metrics such as TMTV, Dmax, sTMTV, sDmax.
+
+    """
+    now = datetime.now()  # current time, date, month,
+    TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime())
+    print("current directory", os.getcwd())
+
+    def __init__(
+            self, config_test: dict = None, preprocessed_dir: str = '../data/test/', data_list: List[str] = None,
+            predicted_dir: str = "../data/predicted"
+            ):
+        """
+
+        :param config_trainer:
+        :param folder_preprocessed_train:
+        :param folder_preprocessed_valid:
+        :param ids_to_read_train:
+        :param ids_to_read_valid:
+        :param task:
+        :param predicted_dir:
+        """
+
+        if config_test is None:
+            self.config_test = deepcopy(default_training_parameters())
+
+        # training data
+        self.preprocessed_dir = preprocessed_dir
+        self.predicted_dir = predicted_dir
+
+        # if the list of testing cases are not given, get from the directory
+        if data_list is None:
+            data_list = os.listdir(preprocessed_dir)
+
+        self.data_list = data_list
+
+        # load the lfb_network architecture
+        self.model = lfbnet.LfbNet()
+
+        # latent feedback at zero time: means no feedback from feedback network
+        self.latent_dim = self.model.latent_dim
+
+        # load the last trained weight in the folder weight
+        print(os.getcwd())
+        folder_path = os.path.join(os.getcwd(), 'src/weight')
+        print(folder_path)
+
+        full_path = [path_i for path_i in glob.glob(str(folder_path) + '/*.h5')]
+
+        print("files \n", full_path)
+        try:
+            max_file = max(full_path, key=os.path.getctime)
+        except:
+            raise Exception("weight could not found !")
+
+        base_name = str(os.path.basename(max_file))
+        print(base_name)
+        self.model.combine_and_train.load_weights(
+            str(folder_path) + '/forward_system' + str(base_name.split('system')[1])
+            )
+        # f
+        self.model.fcn_feedback.load_weights(str(folder_path) + '/feedback_system' + str(base_name.split('system')[1]))
+
+        self.test()
+
+    def test(self):
+        """ Run the model on every test case; evaluate when a ground truth exists.
+
+        Cases that load with a reference segmentation are evaluated (per-sub-system
+        loss/dice via evaluation_test); cases without one are only predicted and saved
+        via prediction().
+        """
+        # path to the test data; the default is ../data/test/
+        folder_preprocessed = self.preprocessed_dir
+        # image folder names, or identifier: if not specified the default values would be the name of the folder inside
+        # the preprocessed directory
+        test_identifier = self.data_list
+
+        ''''
+        declare variables to return if there is a reference segmentation or ground truth : 
+                        forward loss and dice with h0 (no feedback), 
+                        feedback network loss and dice
+                        forward decoder loss and dice,
+                        forward loss and dice with ht (with feedback latent space)
+        '''
+        loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
+                     'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}
+
+        # get the testing ids
+        test_output = []
+        for id_to_test in tqdm(list(test_identifier)):
+            test_input, test_output = NetworkTrainer.load_dataset(
+                directory_=folder_preprocessed, ids_to_read=[id_to_test]
+                )
+
+            # skip (with a message) cases whose data could not be loaded
+            if len(test_input) == 0:
+                print("data %s not read" % id_to_test)
+                continue
+
+            '''
+            if there is a ground truth segmentation (gt), and you would like to compare with the predicted segmentation
+            by the deep learning model
+            '''
+
+            if len(test_output):
+                results = self.evaluation_test(
+                    input_image=test_input.copy(), ground_truth=test_output.copy(), case_name=str(id_to_test)
+                    )
+
+                # append all loss to loss and dice to dice from all cases in valid identifiers
+                for keys in results.keys():
+                    loss_dice[str(keys)].append(results[str(keys)][0])
+
+                print("Results (sagittal and coronal) for case id: %s  : >> " % id_to_test, end=" ")
+                for k, v in loss_dice.items():
+                    print('%s :  %0.3f ' % (k, np.mean(v)), end=" ")
+                print("\n")
+
+            # Predict the segmentation and save in the folder predicted, dataset identifier
+            else:
+                self.prediction(input_image=test_input.copy(), case_name=str(id_to_test))
+
+        """
+        print the mean of the testing loss and dice if there is a ground truth, for all cases 
+        """
+        # NOTE(review): test_output holds only the LAST case's ground truth here, so this
+        # summary prints iff the final case had one — confirm this is intended.
+        if len(test_output):
+            print("Total dataset metrics:  : >> ", end=" ")
+            for k, v in loss_dice.items():
+                print('%s :  %0.3f ' % (k, np.mean(v)), end=" ")
+            print("\n")
+
    def evaluation_test(
            self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None,
            validation_or_test: str = 'validate', case_name: str = None
            ) -> dict:
        """Evaluate the LFBNet sub-networks on one case and save the predicted mask.

        Four evaluations are run against the ground truth:
            1. forward system conditioned on the all-zero feedback state h0,
            2. feedback network on the forward prediction,
            3. forward decoder with the feedback latent spliced into its inputs,
            4. forward system conditioned on the learned feedback latent ht.

        :param verbose: verbosity level forwarded to Keras evaluate()/predict().
        :param input_image: batch of preprocessed MIP images to evaluate on.
        :param ground_truth: reference segmentation masks matching input_image.
        :param validation_or_test: if 'test', skip saving and instead return medpy
            dice/specificity/sensitivity computed on the h0-pass prediction;
            any other value (default 'validate') runs all four evaluations and
            saves the final prediction as NIfTI.
        :param case_name: patient id used to name the saved NIfTI files.
        :return: dict mapping metric name -> single-element list ('validate' path),
            or dict with keys 'dice', 'specificity', 'sensitivity' ('test' path).
        """
        ''''
        declare variables to return: 
                        forward loss and dice with h0 (no feedback), 
                        feedback network loss and dice
                        forward decoder loss and dice,
                        forward loss and dice with ht (with feedback latent space)
        '''
        # NOTE: key 'dice__fwd_h0' (double underscore) must match the caller's dict.
        all_loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [],
                         'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []}
        # Initial feedback latent variable h0: all zeros; batch size taken from the
        # number of input images, remaining dims from self.latent_dim.
        h0_input = np.zeros(
            (len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32
            )

        # step 1:
        # loss and dice of the forward system conditioned on the zero feedback h0
        loss, dice = self.model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose)
        all_loss_dice['loss_fwd_h0'].append(loss), all_loss_dice['dice__fwd_h0'].append(dice)

        # predict the segmentation from the forward system (inputs: image + h0)
        predicted = self.model.combine_and_train.predict([input_image, h0_input])

        # step 2:
        # loss and dice of the feedback network, fed with the forward prediction
        loss, dice = self.model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose)
        all_loss_dice['loss_fdb_h0'].append(loss), all_loss_dice['dice_fdb_h0'].append(dice)

        # step 3: evaluate the forward decoder with the feedback latent as extra input
        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf
        forward_encoder_output = self.model.forward_encoder.predict([input_image])  # forward system's encoder output

        forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
        forward_encoder_output.insert(1, feedback_latent)  # feedback latent goes right after the bottleneck
        loss, dice = self.model.forward_decoder.evaluate(
            [output for output in forward_encoder_output], [ground_truth], verbose=verbose
            )
        all_loss_dice['loss_fwd_decoder'].append(loss), all_loss_dice['dice_fwd_decoder'].append(dice)

        # step 4:
        # loss and dice from the combined model with the feedback latent space: input [input_image, fdb_latent_space]
        loss, dice = self.model.combine_and_train.evaluate(
            [input_image, feedback_latent], [ground_truth], verbose=verbose
            )
        all_loss_dice['loss_fwd_ht'].append(loss), all_loss_dice['dice_fwd_ht'].append(dice)

        """
        For the testing time, we use defined metrics on the predicted images instead of using model.evaluate during 
        the validation cases 
        """
        if validation_or_test == "test":
            # Return dice, specificity and sensitivity computed with medpy.
            # NOTE(review): these metrics are applied to the raw h0-pass network output;
            # presumably 'predicted' is already binarized — TODO confirm upstream.
            return {'dice': binary.dc(predicted, ground_truth),
                    'specificity': binary.specificity(predicted, ground_truth),
                    'sensitivity': binary.sensitivity(predicted, ground_truth)}

        # Final prediction with the feedback latent (ht), post-processed in the sagittal
        # view and saved as NIfTI together with the ground truth and the input PET image.
        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
        predicted = remove_outliers_in_sagittal(predicted)
        save_nii_images(
            [predicted, ground_truth, input_image], identifier=str(case_name),
            name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_pet"],
            path_save= os.path.join(str(self.predicted_dir), 'predicted_data')
            )

        return all_loss_dice
+
+    def prediction(self, input_image: ndarray = None, case_name: str = None):
+        """
+        :param case_name:
+        :param input_image:
+        """
+        # latent  feedback variable h0
+        # replace the first number of batches with the number of input images from the first channel
+        h0_input = np.zeros(
+            (len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32
+            )
+
+        # STEP 1: forward system prediction
+        # predict from the forward system
+        predicted = self.model.combine_and_train.predict([input_image, h0_input])
+
+        # step 2: Feedback system prediction
+        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf
+
+        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
+        predicted = remove_outliers_in_sagittal(predicted)
+        save_nii_images(
+            image=[predicted, input_image], identifier=str(case_name), name=[case_name + "_predicted",
+                                                                             case_name + "_pet"],
+            path_save= os.path.join(str(self.predicted_dir), 'predicted_data')
+            )
+
+
if __name__ == '__main__':
    # Example run: train LFBNet on the REMARC dataset. The first directory holds the
    # preprocessed sagittal/coronal PET MIP images; the second holds the csv/xls file
    # whose 'train' and 'valid' columns list the patient ids for each split.
    mip_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/"
    ids_csv_dir = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/'

    training_ids, validation_ids = get_training_and_validation_ids_from_csv(ids_csv_dir)

    # Train and validate on the same preprocessed directory; the id lists select
    # which patients belong to which split.
    network_trainer = NetworkTrainer(
        folder_preprocessed_train=mip_data_dir,
        folder_preprocessed_valid=mip_data_dir,
        ids_to_read_train=training_ids,
        ids_to_read_valid=validation_ids
        )
    network_trainer.train()