--- a +++ b/src/run/trainer.py @@ -0,0 +1,852 @@ +r""" trainer python script + +This script allows training the proposed lfbnet model. +This script requires specifying the directory path to the preprocessed PET MIP images. It can read the patient ids +from +the given directory path, or it can accept patient ids as .xls and .csv files. Please provide the directory path to +the +csv or xls file. It assumes the csv/xls file has two columns with labels 'train' and 'valid' indicating the training and +validation patient ids respectively. + +Please see the __name__ == '__main__': block as an example, which is equivalent to: + +e.g. train_valid_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/" + train_valid_ids_path_csv = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/' + train_ids, valid_ids = get_training_and_validation_ids_from_csv(train_valid_ids_path_csv) + + trainer = NetworkTrainer( + folder_preprocessed_train=train_valid_data_dir, folder_preprocessed_valid=train_valid_data_dir, + ids_to_read_train=train_ids, + ids_to_read_valid=valid_ids + ) + trainer.train() +""" +# Import libraries +import os +import glob +import sys +import time +from datetime import datetime + +import numpy as np +from numpy.random import seed +from random import randint +from tqdm import tqdm +from typing import Tuple, List +from numpy import ndarray +from copy import deepcopy +from medpy.metric import binary +import matplotlib.pyplot as plt +from keras import backend as K +import re + +# make LFBNet as parent directory, for absolute import libraries. local application import.
p = os.path.abspath('../..')
if p not in sys.path:
    sys.path.append(p)

# import LFBNet modules
from src.LFBNet.data_loader import DataLoader
from src.LFBNet.network_architecture import lfbnet
from src.LFBNet.losses import losses
from src.LFBNet.preprocessing import save_nii_images
from src.LFBNet.utilities import train_valid_paths
from src.LFBNet.postprocessing import remove_outliers_in_sagittal

# choose cuda gpu
CUDA_VISIBLE_DEVICES = 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# make randomness repeatable across experiments.
seed(1)

# Define the parameters of the data to process
K.set_image_data_format('channels_last')

# Keys of the per-case loss/dice record, in reporting order:
# forward system without feedback (h0), feedback system, forward decoder, forward system with feedback (ht).
_LOSS_DICE_KEYS = (
    'loss_fwd_h0', 'dice__fwd_h0', 'loss_fdb_h0', 'dice_fdb_h0',
    'loss_fwd_decoder', 'dice_fwd_decoder', 'loss_fwd_ht', 'dice_fwd_ht',
)


def default_training_parameters(
        num_epochs: int = 5000, batch_size: int = 16, early_stop: int = None, fold_number: int = None,
        model_name_save: List[str] = None, loss: str = None, metric: str = None
) -> dict:
    """ Build the default training configuration dictionary.

    Parameters
    ----------
    num_epochs: int, maximum number of epochs to train the model.
    batch_size: int, number of images per batch.
    early_stop: int, number of epochs without improvement before training stops. Defaults to 50 % of num_epochs.
    fold_number: optional fold identifier for cross-validation-based training. Defaults to a time-stamped string.
    model_name_save: two names, [forward model name, feedback model name]. Defaults to time-stamped names.
    loss: loss function. Defaults to losses.LossMetric.dice_plus_binary_cross_entropy_loss.
    metric: metric function. Defaults to losses.LossMetric.dice_metric.

    Returns
    -------
    dict with the keys 'num_epochs', 'batch_size', 'num_early_stop', 'fold_number', 'model_name_save_forward',
    'model_name_save_feedback', 'custom_loss' and 'custom_dice'.
    """
    if early_stop is None:
        # early stop 50 % of the maximum number of epochs
        early_stop = int(num_epochs * 0.5)

    if fold_number is None:
        fold_number = 'fold_run_at_' + str(time.time())

    if model_name_save is None:
        model_name_save = ["forward_" + str(time.time()), "feedback_" + str(time.time())]

    if loss is None:
        loss = losses.LossMetric.dice_plus_binary_cross_entropy_loss

    if metric is None:
        metric = losses.LossMetric.dice_metric

    config_trainer = {
        'num_epochs': num_epochs, 'batch_size': batch_size, 'num_early_stop': early_stop,
        'fold_number': fold_number, 'model_name_save_forward': model_name_save[0],
        'model_name_save_feedback': model_name_save[1], "custom_loss": loss, "custom_dice": metric,
    }

    return config_trainer


def get_training_and_validation_ids_from_csv(path):
    """ Get training and validation ids from a given csv or xls file.

    Parameters
    ----------
    path: directory path to the csv or xls file; forwarded to train_valid_paths.read_csv_train_valid_index.

    Returns
    -------
    (train, valid): the first and second element returned by read_csv_train_valid_index.
    """
    ids = train_valid_paths.read_csv_train_valid_index(path)
    train, valid = ids[0], ids[1]
    return train, valid


def get_train_valid_ids_from_folder(path_train_valid, ratio_valid_data=0.25):
    """ Randomly split the entries of a folder into training and validation patient ids.

    Parameters
    ----------
    path_train_valid: directory whose entries are the patient ids; a one-element list/tuple holding that
        directory is also accepted.
    ratio_valid_data: fraction of the entries assigned to the validation set.

    Returns
    -------
    [train, valid]: two arrays of patient ids (directory entry names).

    Raises
    ------
    ValueError: if a list/tuple with a number of elements other than one is given.
    """
    if isinstance(path_train_valid, (list, tuple)):
        if len(path_train_valid) != 1:
            raise ValueError("expected exactly one directory holding both training and validation data")
        path_train_valid = path_train_valid[0]

    # sorted so the split is reproducible for a fixed random seed, whatever order the OS lists entries in
    case_ids = np.array(sorted(os.listdir(str(path_train_valid))))
    indices = np.random.permutation(len(case_ids))
    num_valid_data = int(ratio_valid_data * len(case_ids))

    # index with integer arrays; an explicit int dtype keeps the empty-folder case valid
    valid_idx = np.asarray(indices[:num_valid_data], dtype=int)
    train_idx = np.asarray(indices[num_valid_data:], dtype=int)
    train, valid = case_ids[train_idx], case_ids[valid_idx]
    return [train, valid]


def _new_loss_dice_record() -> dict:
    """Return an empty loss/dice record: one empty list per key in _LOSS_DICE_KEYS."""
    return {key: [] for key in _LOSS_DICE_KEYS}


def _mean_loss_dice(records):
    """Return the column-wise mean (loss, dice) of [loss, dice] rows, or (nan, nan) when there are no rows."""
    if len(records) == 0:
        return float('nan'), float('nan')
    loss, dice = np.mean(np.array(records), axis=0)
    return loss, dice


def _evaluate_all_stages(model, latent_dim, input_image, ground_truth, verbose=0):
    """Evaluate every stage of the forward/feedback cycle on one set of images.

    Returns
    -------
    (all_loss_dice, predicted_h0, feedback_latent): the filled loss/dice record, the forward-system prediction
    obtained with an all-zero feedback latent, and the feedback latent computed from that prediction.
    """
    all_loss_dice = _new_loss_dice_record()

    # latent feedback variable h0: all zeros, one entry per input image
    h0_input = np.zeros(
        (len(input_image), int(latent_dim[0]), int(latent_dim[1]), int(latent_dim[2])), np.float32
    )

    # forward system with h0 (no feedback)
    loss, dice = model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose)
    all_loss_dice['loss_fwd_h0'].append(loss)
    all_loss_dice['dice__fwd_h0'].append(dice)

    predicted = model.combine_and_train.predict([input_image, h0_input])

    # feedback system, fed with the forward prediction
    loss, dice = model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose)
    all_loss_dice['loss_fdb_h0'].append(loss)
    all_loss_dice['dice_fdb_h0'].append(dice)

    # forward decoder, fed with the encoder outputs plus the feedback latent
    feedback_latent = model.feedback_latent.predict(predicted)  # feedback: hf
    forward_encoder_output = model.forward_encoder.predict([input_image])
    forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
    forward_encoder_output.insert(1, feedback_latent)
    loss, dice = model.forward_decoder.evaluate(list(forward_encoder_output), [ground_truth], verbose=verbose)
    all_loss_dice['loss_fwd_decoder'].append(loss)
    all_loss_dice['dice_fwd_decoder'].append(dice)

    # forward system with the feedback latent (ht)
    loss, dice = model.combine_and_train.evaluate([input_image, feedback_latent], [ground_truth], verbose=verbose)
    all_loss_dice['loss_fwd_ht'].append(loss)
    all_loss_dice['dice_fwd_ht'].append(dice)

    return all_loss_dice, predicted, feedback_latent


class NetworkTrainer:
    """
    Train and validate the LFBNet forward and feedback networks.

    The best validation metric/loss and the early-stop counter are class attributes, so they are shared by all
    instances of this class.
    """
    BEST_METRIC_VALIDATION = 0  # best validation metric so far (dice)
    BEST_LOSS_VALIDATION = 100  # validation loss recorded together with the best metric
    EARLY_STOP_COUNT = 0  # number of consecutive validations without improvement
    now = datetime.now()  # time at class creation
    TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime())

    def __init__(
            self, config_trainer: dict = None, folder_preprocessed_train: str = '../data/train/',
            folder_preprocessed_valid: str = '../data/valid/', ids_to_read_train: ndarray = None,
            ids_to_read_valid: ndarray = None, task: str = 'valid', predicted_directory: str = '../data/predicted/',
            save_predicted: bool = False
    ):
        """
        :param config_trainer: training configuration; defaults to default_training_parameters().
        :param folder_preprocessed_train: directory with the training cases.
        :param folder_preprocessed_valid: directory with the validation cases.
        :param ids_to_read_train: training case ids; defaults to the entries of folder_preprocessed_train.
        :param ids_to_read_valid: validation case ids; defaults to the entries of folder_preprocessed_valid.
        :param task: 'train' to train, 'valid' to load the latest weights and validate only.
        :param predicted_directory: directory where predictions are saved.
        :param save_predicted: save the prediction of every evaluated case when True.
        """
        # a caller-supplied configuration used to be ignored, leaving self.config_trainer undefined
        if config_trainer is None:
            config_trainer = deepcopy(default_training_parameters())
        self.config_trainer = config_trainer

        # training data
        self.folder_preprocessed_train = folder_preprocessed_train
        if ids_to_read_train is None:
            ids_to_read_train = os.listdir(folder_preprocessed_train)
        self.ids_to_read_train = ids_to_read_train

        # validation data
        self.folder_preprocessed_valid = folder_preprocessed_valid
        if ids_to_read_valid is None:
            ids_to_read_valid = os.listdir(folder_preprocessed_valid)
        self.ids_to_read_valid = ids_to_read_valid

        # save predicted directory:
        self.save_all = save_predicted
        self.predicted_directory = predicted_directory

        # load the lfb_network architecture
        self.model = lfbnet.LfbNet()
        self.task = task

        # latent feedback at zero time: means no feedback from feedback network
        self.latent_dim = self.model.latent_dim
        self.h_at_zero_time = np.zeros(
            (int(self.config_trainer['batch_size']), int(self.latent_dim[0]), int(self.latent_dim[1]),
             int(self.latent_dim[2])), np.float32
        )

    @staticmethod
    def load_dataset(directory_: str = None, ids_to_read: List[str] = None):
        """ Load the images and ground truths of the given cases and append a trailing channel axis to both.

        :param directory_: directory passed to DataLoader as data_dir.
        :param ids_to_read: case ids passed to DataLoader.
        :return: (batch_input_data, batch_output_data)
        """
        data_loader = DataLoader(data_dir=directory_, ids_to_read=ids_to_read)
        image_batch_ground_truth_batch = data_loader.get_batch_of_data()

        batch_input_data, batch_output_data = image_batch_ground_truth_batch[0], image_batch_ground_truth_batch[1]
        # expand dimension for the channel
        batch_output_data = np.expand_dims(batch_output_data, axis=-1)
        batch_input_data = np.expand_dims(batch_input_data, axis=-1)

        return batch_input_data, batch_output_data

    def load_latest_weight(self):
        """ Load the forward and feedback weights that match the most recently created .h5 file in ./weight.

        :raises FileNotFoundError: if ./weight contains no .h5 file.
        """
        folder_path = './weight/'
        # the previous pattern r'\*.h5' only matched when the backslash is a path separator (Windows)
        files = glob.glob(os.path.join(folder_path, '*.h5'))
        if not files:
            raise FileNotFoundError("weight could not found !")
        max_file = max(files, key=os.path.getctime)

        base_name = str(os.path.basename(max_file))
        print(base_name)
        # everything after 'system' is shared by the forward and the feedback weight file names
        suffix = str(base_name.split('system')[1])
        self.model.combine_and_train.load_weights('./weight/forward_system' + suffix)
        self.model.fcn_feedback.load_weights('./weight/feedback_system' + suffix)

    def _save_weights(self, tag: str = ''):
        """Save the forward and feedback models under weight/, creating the directory when needed."""
        os.makedirs('weight', exist_ok=True)
        suffix = tag + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5'
        self.model.combine_and_train.save('weight/forward_system_' + suffix)
        self.model.fcn_feedback.save('weight/feedback_system_' + suffix)

    def _train_on_batch(self, current_epoch, batch_input, batch_output, forward_loss_dice, feedback_loss_dice,
                        forward_decoder_loss_dice):
        """Run one training step on one batch and append its [loss, dice] to the matching record list."""
        if current_epoch % 2 == 0:
            # step 1 (even epochs): train the forward network encoder and decoder
            loss, dice = self.model.combine_and_train.train_on_batch(
                [batch_input, self.h_at_zero_time], [batch_output]
            )
            forward_loss_dice.append([loss, dice])
            return

        predicted_decoder = self.model.combine_and_train.predict([batch_input, self.h_at_zero_time])

        # step 2 (odd epochs): train the feedback network on the output of the forward network
        loss, dice = self.model.fcn_feedback.train_on_batch(predicted_decoder, batch_output)
        feedback_loss_dice.append([loss, dice])

        # step 3 (odd epochs): train the forward decoder with the feedback latent inserted
        feedback_latent_result = self.model.feedback_latent.predict([predicted_decoder])
        forward_encoder_output = self.model.forward_encoder.predict([batch_input])
        forward_encoder_output = forward_encoder_output[::-1]  # bottleneck should be first
        forward_encoder_output.insert(1, feedback_latent_result)
        loss, dice = self.model.forward_decoder.train_on_batch(list(forward_encoder_output), [batch_output])
        forward_decoder_loss_dice.append([loss, dice])

    def train(self):
        """Train the model when task == 'train'; load the latest weights and validate when task == 'valid'.
        """
        batch_size = self.config_trainer['batch_size']

        if self.task == 'train':
            for current_epoch in range(self.config_trainer['num_epochs']):
                feedback_loss_dice = []
                forward_loss_dice = []
                forward_decoder_loss_dice = []

                # shuffle the index of the training data
                index_read = np.random.permutation(int(len(self.ids_to_read_train)))

                # one iteration per group of batch_size case ids
                for start in range(0, len(index_read), batch_size):
                    selected = set(index_read[start:start + batch_size].tolist())
                    kk = [str(k) for i, k in enumerate(self.ids_to_read_train) if i in selected]

                    batch_input_data, batch_output_data = self.load_dataset(
                        directory_=self.folder_preprocessed_train, ids_to_read=kk
                    )

                    assert len(batch_input_data) > 0, "batch of data not loaded correctly"

                    # shuffle within the loaded group
                    index_batch = np.random.permutation(int(batch_input_data.shape[0]))
                    batch_input_data = batch_input_data[index_batch]
                    batch_output_data = batch_output_data[index_batch]

                    # the loaded group can hold more images than batch_size; only full batches are used because
                    # h_at_zero_time has a fixed first dimension of batch_size
                    batch_per_epoch = int(batch_input_data.shape[0] / batch_size)
                    for batch_number in range(batch_per_epoch):
                        first, last = batch_number * batch_size, (batch_number + 1) * batch_size
                        self._train_on_batch(
                            current_epoch, batch_input_data[first:last], batch_output_data[first:last],
                            forward_loss_dice, feedback_loss_dice, forward_decoder_loss_dice
                        )

                if current_epoch % 2 == 0:
                    loss, dice = _mean_loss_dice(forward_loss_dice)
                    print(
                        'Training_forward_system: >%d, '
                        ' fwd_loss = %.3f, fwd_dice=%0.3f, ' % (current_epoch, loss, dice)
                    )
                else:
                    loss_forward, dice_forward = _mean_loss_dice(forward_decoder_loss_dice)
                    loss_feedback, dice_feedback = _mean_loss_dice(feedback_loss_dice)

                    # adjacent literals are joined before '%' is applied; the former '+' between them made '%'
                    # apply to the last two placeholders only and raised TypeError
                    print(
                        'Training_forward_decoder_and_feedback_system: >%d, '
                        'fwd_decoder_loss=%03f, '
                        'fwd_decoder_dice=%0.3f '
                        'fdb_loss=%03f, '
                        'fdb_dice=%.3f ' % (current_epoch, loss_forward, dice_forward, loss_feedback, dice_feedback)
                    )

                # validation test:
                self.validation(current_epoch=current_epoch)

                # stop when the validation metric has not improved for num_early_stop consecutive epochs
                if NetworkTrainer.EARLY_STOP_COUNT >= self.config_trainer['num_early_stop']:
                    self._save_weights(tag='early_stopped_')
                    break

        # if not training load the last saved weights, and check validation
        elif self.task == 'valid':
            self.load_latest_weight()
            self.validation(current_epoch=self.config_trainer['num_epochs'])

    def validation(self, verbose: int = 0, current_epoch: int = None):
        """
        Evaluate every validation case, print the mean losses and metrics, save the weights when the mean
        'dice_fwd_ht' is at least as good as the best one so far, and otherwise increase the early-stop counter.

        :param verbose: currently not forwarded to evaluation().
        :param current_epoch: epoch number used in the printed report; epoch 0 resets the best values.
        """
        folder_preprocessed = self.folder_preprocessed_valid
        valid_identifier = self.ids_to_read_valid

        loss_dice = _new_loss_dice_record()
        all_dice_sen_sep = {'dice': [], 'specificity': [], 'sensitivity': []}

        for id_to_validate in valid_identifier:
            # keep only the part of the identifier before the first '.'
            id_to_validate = str(id_to_validate).split('.')[0]

            valid_input, valid_output = self.load_dataset(directory_=folder_preprocessed, ids_to_read=[id_to_validate])

            if len(valid_input) == 0:
                print("data %s not read" % id_to_validate)
                continue

            results, dice_sen_sep = self.evaluation(
                input_image=valid_input.copy(), ground_truth=valid_output.copy(), case_name=str(id_to_validate)
            )

            # collect the single value stored per key for this case
            for key, values in results.items():
                loss_dice[str(key)].append(values[0])

            for key, values in dice_sen_sep.items():
                all_dice_sen_sep[str(key)].append(values[0])

        print("\n Dice, sensitivity, specificity \t")
        for k, v in all_dice_sen_sep.items():
            print('%s : %0.3f ' % (k, np.mean(list(v), axis=0)), end=" ")
        print("\n")

        # the stopping criterion uses the forward system with feedback latent (3rd step)
        dice_mean = np.mean(loss_dice['dice_fwd_ht'])
        loss_mean = np.mean(loss_dice['loss_fwd_ht'])

        # at the first epoch
        if current_epoch == 0:
            NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean
            NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean

        print("Current validation loss and metrics at epoch %d: >> " % current_epoch, end=" ")
        for k, v in loss_dice.items():
            print('%s : %0.3f ' % (k, np.mean(v)), end=" ")
        print("\n")

        if NetworkTrainer.BEST_METRIC_VALIDATION <= dice_mean:
            # reset early stop count, best dice, and best loss values
            NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean
            NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean
            NetworkTrainer.EARLY_STOP_COUNT = 0

            # save the best model weights
            self._save_weights()
        else:
            # no improvement: count towards early stopping
            NetworkTrainer.EARLY_STOP_COUNT += 1

        print(
            '\n Best model on validation data : %0.3f : Dice: %0.3f \n' % (
                NetworkTrainer.BEST_LOSS_VALIDATION, NetworkTrainer.BEST_METRIC_VALIDATION)
        )

    def evaluation(
            self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None,
            validation_or_test: str = 'test', case_name: str = None
    ):
        """
        Evaluate one case at every stage of the forward/feedback cycle and compute binary metrics on the
        thresholded final prediction.

        :param verbose: verbosity forwarded to the model evaluate calls.
        :param input_image: input images of the case.
        :param ground_truth: reference segmentation; it is thresholded in place.
        :param validation_or_test: unused.
        :param case_name: case identifier used to name saved images.
        :return: (all_loss_dice, dice_sen_sp): the loss/dice record and a dict with 'dice', 'specificity' and
            'sensitivity', each holding a one-element list.
        """
        all_loss_dice, _, feedback_latent = _evaluate_all_stages(
            self.model, self.latent_dim, input_image, ground_truth, verbose
        )

        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])

        # binary.dc, sensitivity and specificity work only on binary images. The prediction is thresholded on a
        # copy so the raw prediction can still be saved; the ground truth is thresholded in place.
        predicted_binary = NetworkTrainer.threshold_image(predicted.copy())
        truth_binary = NetworkTrainer.threshold_image(ground_truth)

        dice_sen_sp = {'dice': [], 'specificity': [], 'sensitivity': []}
        dice_sen_sp['dice'].append(binary.dc(predicted_binary, truth_binary))
        dice_sen_sp['sensitivity'].append(binary.sensitivity(predicted_binary, truth_binary))
        dice_sen_sp['specificity'].append(binary.specificity(predicted_binary, truth_binary))

        # save every case when requested, otherwise a random subset (randint(0, 10) divisible by 3)
        if self.save_all or randint(0, 10) % 3 == 0:
            save_nii_images(
                [predicted, ground_truth, input_image], identifier=str(case_name),
                name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_image"],
                path_save=self.predicted_directory
            )

        return all_loss_dice, dice_sen_sp

    @staticmethod
    def display_image(im_display: ndarray):
        """ Display the given images in a two-column grid.

        :param im_display: sequence of 2D image arrays to display.
        """
        n_cols = 2
        # at least the former 3 x 2 layout, with extra rows when more than six images are given
        n_rows = max(3, -(-len(im_display) // n_cols))

        plt.figure(figsize=(10, 8))
        plt.subplots_adjust(hspace=0.5)
        plt.suptitle("LFBNet images", fontsize=18, y=0.95)
        for n, im in enumerate(im_display):
            plt.subplot(n_rows, n_cols, n + 1)
            plt.imshow(im)
        plt.show()

    @staticmethod
    def threshold_image(im_: ndarray, thr_value: float = 0.5) -> ndarray:
        """ Binarize the given array IN PLACE: values above thr_value become 1, all other values become 0.

        :param im_: ndarray of images; modified in place.
        :param thr_value: thresholding value.
        :return: the same array object, now holding only 0 and 1.
        """
        # compute the mask once: values equal to the threshold are now zeroed, and a threshold above 1 no longer
        # wipes out the freshly written ones
        above = im_ > thr_value
        im_[above] = 1
        im_[~above] = 0
        return im_


class ModelTesting:
    """ Perform prediction on a given data set with the most recently saved weights.

    The constructor loads the weights and immediately runs test().
    """
    now = datetime.now()  # time at class creation
    TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime())
    print("current directory", os.getcwd())

    def __init__(
            self, config_test: dict = None, preprocessed_dir: str = '../data/test/', data_list: List[str] = None,
            predicted_dir: str = "../data/predicted"
    ):
        """
        :param config_test: configuration; defaults to default_training_parameters().
        :param preprocessed_dir: directory with the cases to test.
        :param data_list: case ids to test; defaults to the entries of preprocessed_dir.
        :param predicted_dir: directory under which 'predicted_data' is written.
        :raises FileNotFoundError: if <current working directory>/src/weight contains no .h5 file.
        """
        # a caller-supplied configuration used to be ignored, leaving self.config_test undefined
        if config_test is None:
            config_test = deepcopy(default_training_parameters())
        self.config_test = config_test

        self.preprocessed_dir = preprocessed_dir
        self.predicted_dir = predicted_dir

        # if the list of testing cases are not given, get from the directory
        if data_list is None:
            data_list = os.listdir(preprocessed_dir)
        self.data_list = data_list

        # load the lfb_network architecture
        self.model = lfbnet.LfbNet()

        # latent feedback at zero time: means no feedback from feedback network
        self.latent_dim = self.model.latent_dim

        # load the most recently created weight in <cwd>/src/weight
        print(os.getcwd())
        folder_path = os.path.join(os.getcwd(), 'src/weight')
        print(folder_path)

        full_path = glob.glob(str(folder_path) + '/*.h5')
        print("files \n", full_path)
        if not full_path:
            raise FileNotFoundError("weight could not found !")
        max_file = max(full_path, key=os.path.getctime)

        base_name = str(os.path.basename(max_file))
        print(base_name)
        # everything after 'system' is shared by the forward and the feedback weight file names
        suffix = str(base_name.split('system')[1])
        self.model.combine_and_train.load_weights(str(folder_path) + '/forward_system' + suffix)
        self.model.fcn_feedback.load_weights(str(folder_path) + '/feedback_system' + suffix)

        self.test()

    def test(self):
        """
        Process every test case: evaluate and save it when a ground truth is loaded, otherwise only predict and
        save. Prints running mean metrics per case and the overall means at the end.
        """
        folder_preprocessed = self.preprocessed_dir
        test_identifier = self.data_list

        loss_dice = _new_loss_dice_record()

        # set once any case provides a ground truth, so the summary no longer depends on the last case only
        has_ground_truth = False
        for id_to_test in tqdm(list(test_identifier)):
            test_input, test_output = NetworkTrainer.load_dataset(
                directory_=folder_preprocessed, ids_to_read=[id_to_test]
            )

            if len(test_input) == 0:
                print("data %s not read" % id_to_test)
                continue

            if len(test_output):
                # a ground truth segmentation is available: compare it with the predicted segmentation
                has_ground_truth = True
                results = self.evaluation_test(
                    input_image=test_input.copy(), ground_truth=test_output.copy(), case_name=str(id_to_test)
                )

                for key, values in results.items():
                    loss_dice[str(key)].append(values[0])

                print("Results (sagittal and coronal) for case id: %s : >> " % id_to_test, end=" ")
                for k, v in loss_dice.items():
                    print('%s : %0.3f ' % (k, np.mean(v)), end=" ")
                print("\n")
            else:
                # no ground truth: predict the segmentation and save it
                self.prediction(input_image=test_input.copy(), case_name=str(id_to_test))

        if has_ground_truth:
            print("Total dataset metrics: : >> ", end=" ")
            for k, v in loss_dice.items():
                print('%s : %0.3f ' % (k, np.mean(v)), end=" ")
            print("\n")

    def evaluation_test(
            self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None,
            validation_or_test: str = 'validate', case_name: str = None
    ):
        """
        Evaluate one case at every stage of the forward/feedback cycle, then save the outlier-filtered final
        prediction together with the ground truth and the input image.

        :param verbose: verbosity forwarded to the model evaluate calls.
        :param input_image: input images of the case.
        :param ground_truth: reference segmentation.
        :param validation_or_test: when "test", return binary metrics instead of saving (see note below).
        :param case_name: case identifier used to name saved images.
        :return: the loss/dice record, or a dict of 'dice', 'specificity', 'sensitivity' in "test" mode.
        """
        all_loss_dice, predicted, feedback_latent = _evaluate_all_stages(
            self.model, self.latent_dim, input_image, ground_truth, verbose
        )

        if validation_or_test == "test":
            # NOTE(review): this branch scores the un-thresholded h0 prediction and returns scalar values under
            # keys that test() does not expect; test() never selects it. Confirm the intended behaviour.
            return {'dice': binary.dc(predicted, ground_truth),
                    'specificity': binary.specificity(predicted, ground_truth),
                    'sensitivity': binary.sensitivity(predicted, ground_truth)}

        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
        predicted = remove_outliers_in_sagittal(predicted)
        save_nii_images(
            [predicted, ground_truth, input_image], identifier=str(case_name),
            name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_pet"],
            path_save=os.path.join(str(self.predicted_dir), 'predicted_data')
        )

        return all_loss_dice

    def prediction(self, input_image: ndarray = None, case_name: str = None):
        """
        Predict the segmentation of one case without ground truth and save it together with the input image.

        :param input_image: input images of the case.
        :param case_name: case identifier used to name saved images.
        """
        # latent feedback variable h0: all zeros, one entry per input image
        h0_input = np.zeros(
            (len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32
        )

        # step 1: forward system prediction without feedback
        predicted = self.model.combine_and_train.predict([input_image, h0_input])

        # step 2: feedback latent computed from that prediction
        feedback_latent = self.model.feedback_latent.predict(predicted)  # feedback: hf

        # step 3: forward system prediction with the feedback latent
        predicted = self.model.combine_and_train.predict([input_image, feedback_latent])
        predicted = remove_outliers_in_sagittal(predicted)
        save_nii_images(
            image=[predicted, input_image], identifier=str(case_name),
            name=[case_name + "_predicted", case_name + "_pet"],
            path_save=os.path.join(str(self.predicted_dir), 'predicted_data')
        )


if __name__ == '__main__':
    train_valid_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/"
    train_valid_ids_path_csv = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/'
    train_ids, valid_ids = get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)

    trainer = NetworkTrainer(
        folder_preprocessed_train=train_valid_data_dir, folder_preprocessed_valid=train_valid_data_dir,
        ids_to_read_train=train_ids, ids_to_read_valid=valid_ids
    )
    trainer.train()