Diff of /train_2d.py [000000] .. [e50482]

Switch to side-by-side view

--- a
+++ b/train_2d.py
@@ -0,0 +1,545 @@
+import glob
+import shutil
+import argparse
+import monai
+import torch
+from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
+from torchmetrics.functional import iou, dice_score
+from torch import nn
+from pytorch_lightning import Trainer
+import pytorch_lightning as pl
+import numpy as np
+from torch.utils.data import DataLoader
+from sklearn.model_selection import KFold
+from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
+import datetime
+import monai.transforms as mt
+from torchmetrics import IoU, F1
+from attn_unet_2d import AttU_Net2D
+from unet_2d import Unet_2d
+from data_2d import train_loader_ACDC, val_loader_ACDC
+from monai.losses.dice import DiceLoss, DiceCELoss
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator as ea
+import os
+import matplotlib.pyplot as plt
+from pathlib import Path
+
# Manual seeding for reproducible weight init / shuffling
torch.manual_seed(42)

"""-----------------------Arguments-----------------------"""
# Command-line hyperparameters.
# FIX: argparse applies `type` only to values parsed from the command line,
# so defaults must already carry the intended type. The lr/lr_decay defaults
# were np.float32 while CLI-supplied values would be python floats; they are
# now plain float literals so both paths yield the same type and value.
parser = argparse.ArgumentParser(description='Training of UNet2D Segmentation')
parser.add_argument("--model_choice", default="UNet2D_Attention", type=str)
parser.add_argument("--kfolds", default=5, type=int)
parser.add_argument("--Loss_choice", default="dice", type=str)
parser.add_argument("--Batch_size_train", default=10, type=int)
parser.add_argument("--Batch_size_val", default=1, type=int)
parser.add_argument("--lr", default=0.0005, type=float)
parser.add_argument("--lr_decay", default=0.985, type=float)
parser.add_argument("--gpus", default=1, type=int)
parser.add_argument("--maximum_epochs", default=400, type=int)
parser.add_argument("--patience_early_stop", default=400, type=int)
parser.add_argument('--optimizer_choice', default='adam', type=str)
parser.add_argument('--scheduler_choice', default='plateau', type=str)
parser.add_argument('--dropout_rate', default=0.3, type=float)

"""--------------Models, Hyperparameters, Metrics and Variables--------------"""
# Unpack parsed arguments into the module-level names used throughout the file.
arguments = parser.parse_args()
model_choice = arguments.model_choice
k_folds = arguments.kfolds
loss_choice = arguments.Loss_choice
batch_size_train = arguments.Batch_size_train
batch_size_val = arguments.Batch_size_val
learning_rate = arguments.lr
LR_decay_rate = arguments.lr_decay
dev = arguments.gpus
max_epochs = arguments.maximum_epochs
patience = arguments.patience_early_stop
optim_choice = arguments.optimizer_choice
scheduler_choice = arguments.scheduler_choice
drop_rate = arguments.dropout_rate

# Echo the full configuration so it is captured in the job log.
print("Model Choice:", model_choice,
      "Dropout Rate", drop_rate,
      "K Folds:", k_folds,
      "Loss Choice:", loss_choice,
      "LR Decay Rate:", LR_decay_rate,
      "Device:", dev,
      "Patience Early Stopping:", patience,
      "Optimizer:", optim_choice,
      "Scheduler:", scheduler_choice)
+
# Model
# `model_choice` comes from the CLI; both networks are local project modules.
if model_choice == "UNet2D":
    my_model = Unet_2d(drop=drop_rate).cuda()  # with upsample --> Has more parameters
elif model_choice == "UNet2D_Attention":
    my_model = AttU_Net2D(drop=drop_rate).cuda()  # with Attention --> Has the most parameters
else:
    raise ValueError("Wrong model choice!")
# Loss Function
# Both MONAI losses one-hot the target internally (to_onehot_y=True) and apply
# softmax to the raw logits (softmax=True), so un-normalised model output is
# passed straight in.
if loss_choice == "dice_ce":
    loss_func = DiceCELoss(include_background=True,
                           to_onehot_y=True,
                           sigmoid=False,
                           softmax=True,
                           jaccard=False,
                           reduction="mean",
                           smooth_nr=1e-05,
                           smooth_dr=1e-05,
                           # ce_weight=class_weights,
                           batch=False).cuda()
elif loss_choice == "dice":
    loss_func = DiceLoss(include_background=True,
                         to_onehot_y=True,
                         sigmoid=False,
                         softmax=True,
                         jaccard=False,
                         reduction="mean",
                         smooth_nr=1e-05,
                         smooth_dr=1e-05,
                         batch=False).cuda()
else:
    raise ValueError("Wrong loss choice!")

# IoU
# Class-based metric objects; the training/validation steps below actually use
# the torchmetrics *functional* API, so these two are currently unused spares.
IOU_metric = IoU(num_classes=4, absent_score=-1., reduction="none").cuda()
# F1 score
f1_metric = F1(num_classes=4, mdmc_average="samplewise", average='none').cuda()
# Softmax
soft = torch.nn.Softmax(dim=1)  # over the class dimension of B C H W logits
# Required dimensions
tar_shape = [300, 300]   # canvas size before the random training crop
crop_shape = [224, 224]  # training crop fed to the network

"""---------Post Processing---------"""
# Keep only the largest connected component per label to clean up predictions.
keep_largest = monai.transforms.KeepLargestConnectedComponent(applied_labels=[0, 1, 2, 3])
+
+"""---------Augmentations---------"""
+
+train_transform = mt.Compose(
+    [mt.ResizeWithPadOrCropD(keys=["image", "mask"], spatial_size=tar_shape, mode="constant"),
+     mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
+     mt.Rand2DElasticD(
+         keys=["image", "mask"],
+         prob=0.25,
+         spacing=(50, 50),
+         magnitude_range=(1, 3),
+         rotate_range=(np.pi / 4,),
+         scale_range=(0.1, 0.1),
+         translate_range=(10, 10),
+         padding_mode="border",
+     ),
+     # mt.RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
+     mt.RandFlipd(["image", "mask"], spatial_axis=[0], prob=0.5),
+     mt.RandFlipd(["image", "mask"], spatial_axis=[1], prob=0.5),
+     mt.RandRotateD(keys=["image", "mask"], range_x=np.pi / 4, range_y=np.pi / 4, range_z=0.0, prob=0.50,
+                    keep_size=True, mode=("nearest", "nearest"), align_corners=False),
+     mt.RandRotate90D(keys=["image", "mask"], prob=0.25, spatial_axes=(0, 1)),
+     mt.RandGaussianNoiseD(keys=["image"], prob=0.15, std=0.01),
+     mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
+     mt.RandZoomd(
+         keys=["image", "mask"],
+         min_zoom=0.9,
+         max_zoom=1.2,
+         mode="nearest",
+         align_corners=None,
+         prob=0.25,
+     ),
+     mt.RandKSpaceSpikeNoiseD(keys=["image"], prob=0.15, intensity_range=(5.0, 7.5)),
+     ]
+)
+val_transform = mt.Compose([
+    mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False)
+])
+
+"""------------------Datasets and Directories to save the results------------------"""
+# Define the K-fold Cross Validator
+splits = KFold(n_splits=k_folds, shuffle=True, random_state=42)
+# train + val dataset for 5 fold cross validation training
+concatenated_dataset = train_loader_ACDC(transform=None, train_index=None)
+
+#  paths to store the checkpoints
+if not os.path.exists(r"../unet/checkpoints"):
+    os.makedirs(r"../unet/checkpoints")
+checkpoint_path = "../unet/checkpoints"
+
+if not os.path.exists(r"../unet/tb_logs"):
+    os.makedirs(r"../unet/tb_logs")
+tb_path = "../unet/tb_logs"
+
+if not os.path.exists(r"../unet/csv_logs"):
+    os.makedirs(r"../unet/csv_logs")
+csv_path = "../unet/csv_logs"
+
+#  Temporarily store the validated image and ground truth plots --> to be moved to the respective folders
+if not os.path.exists(r'../unet/val_images_temp_2d/'):
+    os.makedirs(r'../unet/val_images_temp_2d/')
+val_path = r'../unet/val_images_temp_2d/'
+
+# Save the validation images and ground truths
+if not os.path.exists(r'../unet/val_images_save_2d/'):
+    os.makedirs(r'../unet/val_images_save_2d/')
+image_path = r'../unet/val_images_save_2d/'
+
+
# pad the images so that they are divisible by 16 (the UNet halves H/W four times)
def Pad_images(image):
    """Center-pad a B x C x H x W tensor so H and W are multiples of 16.

    The padding value is the image minimum. Returns the padded tensor and the
    (top, left) offsets needed by UnPad_imges to undo the padding.

    FIX: the original computed ``(16 - side % 16) + side``, which adds a full
    16-pixel band even when the side is already a multiple of 16; ``-side % 16``
    is zero in that case, so already-divisible images pass through unpadded.
    The redundant second read of ``image.shape`` was also removed.
    """
    b, c, h, w = image.shape
    pad_h = -h % 16  # 0 when h is already divisible by 16
    pad_w = -w % 16
    new_h = h + pad_h
    new_w = w + pad_w
    # allocate the padded canvas filled with the image minimum
    result = torch.Tensor(b, c, new_h, new_w).fill_(image.min())
    xx = (new_h - h) // 2
    yy = (new_w - w) // 2
    result[:, :, xx:xx + h, yy:yy + w] = image
    return result, (xx, yy)  # result is a torch tensor in CPU --> have to move to GPU
+
+
# Undo Pad_images: crop a padded tensor back to the original region.
def UnPad_imges(image, indices, org_shape):
    """Crop *image* back to *org_shape* using the (top, left) pad offsets."""
    _, _, height, width = org_shape
    top, left = indices
    # slicing returns a view on the same device; caller handles any GPU move
    return image[:, :, top:top + height, left:left + width]
+
+
# Re-initialise convolution layers so each fold starts from fresh parameters.
def reset_weights(m):
    """Reset parameters of (transposed) 2-D/3-D conv layers; leave others untouched."""
    conv_types = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    if isinstance(m, conv_types):
        m.reset_parameters()
+
+
# save the ground-truth masks to both the temp and the permanent directory
def save_plots_mask(target, idx):
    """Save the ground-truth mask as '<idx>_gt.png' in val_path and image_path.

    BUG FIX: the original did ``plt.title = image_file_name``, which rebinds
    the matplotlib ``pyplot.title`` *function* to a plain string and breaks any
    later call to it; ``plt.imsave`` never renders titles, so the assignment
    (and the now-unused name variable) is removed.
    """
    out_path = os.path.join(val_path, f"{idx}_gt" + "." + 'png')
    out_save_path = os.path.join(image_path, f"{idx}_gt" + "." + 'png')
    target = target.squeeze()
    target = np.array(target.cpu())
    plt.imsave(out_save_path, target)
    plt.imsave(out_path, target, format='png')
    plt.close()
+
+
# save the post-processed predictions to both the temp and the permanent directory
def save_plots_pred(pred, idx):
    """Save the prediction as '<idx>_pred.png' in val_path and image_path.

    Applies softmax -> argmax -> largest-connected-component filtering before
    writing. BUG FIX: removed ``plt.title = ...``, which rebound the
    ``pyplot.title`` function to a string (breaking later matplotlib calls);
    ``plt.imsave`` never draws a title anyway.
    """
    out_path = os.path.join(val_path, f"{idx}_pred" + "." + 'png')
    out_save_path = os.path.join(image_path, f"{idx}_pred" + "." + 'png')
    soft_pred_log = soft(pred)
    final_pred_log = torch.argmax(soft_pred_log, dim=1)
    #  Post Processing after softmax and argmax
    final_pred_log = keep_largest(final_pred_log)
    final_pred_log = np.array(final_pred_log.cpu().squeeze())
    plt.imsave(out_save_path, final_pred_log)
    plt.imsave(out_path, final_pred_log, format='png')
    plt.close()
+
+
# save the input images to both the temp and the permanent directory
def save_plots_image(img, idx):
    """Save the input image as '<idx>_image.png' in val_path and image_path.

    BUG FIX: removed ``plt.title = ...``, which rebound the ``pyplot.title``
    function to a string (breaking later matplotlib calls); ``plt.imsave``
    never draws a title anyway.
    """
    out_path = os.path.join(val_path, f"{idx}_image" + "." + 'png')
    out_save_path = os.path.join(image_path, f"{idx}_image" + "." + 'png')
    final_image = np.array(img.cpu().squeeze())
    plt.imsave(out_save_path, final_image)
    plt.imsave(out_path, final_image, format='png')
    plt.close()
+
+
class Train2D(pl.LightningModule):
    """LightningModule for 4-class 2-D cardiac segmentation.

    Wraps the module-level globals ``my_model`` (network) and ``loss_func``
    (MONAI Dice / Dice+CE). IoU and Dice are computed with the torchmetrics
    functional API; a score of -1 flags classes absent from both prediction
    and target and is filtered out before averaging so absent classes do not
    drag the mean down.
    """

    def __init__(self):
        super(Train2D, self).__init__()
        self.net = my_model             # selected at import time from the CLI args
        self.loss_function = loss_func  # MONAI Dice or Dice+CE

    def forward(self, x):
        return self.net(x)  # returns output of the model --> B Classes H W

    def training_step(self, batch, batch_idx):
        """Run one training batch; returns loss plus batch-level IoU/Dice."""
        img, mask = batch["image"], batch["mask"]  # image --> torch.float(), mask --> torch.Long
        img = img.float()  # B Channels H W
        mask = mask.long()  # B Channels H W
        # image passed through the model
        out = self(img)  # B Classes H W
        # calculate loss (the MONAI loss applies softmax + one-hot internally)
        loss = self.loss_function(out, mask)
        # calculate softmax of the prediction
        soft_out = soft(out)  # softmax of the prediction
        mask = mask.squeeze(dim=1)  # B H W
        """ Calculation of metrics using Torchmetrics"""
        # # calculate iou
        # iou_all = IOU_metric(soft_out, mask)
        # # train_iou = iou_all.mean()
        # iou_all = iou_all[iou_all != -1.]
        # if len(iou_all) == 0:
        #     train_iou = torch.tensor(0.0).cuda()
        # else:
        #     train_iou = iou_all.mean()
        # # calculate dice score
        # dice_all = f1_metric(soft_out, mask)
        # # train_dice = dice_all.mean()
        # dice_all_np = dice_all.cpu().numpy()
        # dice_all_np = dice_all_np[~np.isnan(dice_all_np)]
        # dice_all = (torch.from_numpy(dice_all_np).cuda())
        # if len(dice_all) == 0:
        #     train_dice = torch.tensor(0.0).cuda()
        # else:
        #     train_dice = dice_all.mean()
        """ Calculation of metrics using Torchmetrics functional"""
        # iou: absent_score=-1 marks classes missing from both pred and target;
        # drop those, fall back to 0.0 if every class was absent
        iou_all = iou(soft_out, mask, absent_score=-1., num_classes=4, reduction='none', ignore_index=None)
        iou_all = iou_all[iou_all != -1.]
        if len(iou_all) == 0:
            train_iou = torch.tensor(0.0).cuda()
        else:
            train_iou = iou_all.mean()
        # dice score: same -1 sentinel scheme via no_fg_score
        dice_all = dice_score(soft_out, mask, bg=True, no_fg_score=-1., reduction='none')
        dice_all = dice_all[dice_all != -1.]
        if len(dice_all) == 0:
            train_dice = torch.tensor(0.0).cuda()
        else:
            train_dice = dice_all.mean()
        # logger --> log the train_loss per step; epoch averages are logged in training_epoch_end
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False)
        return {"loss": loss, "train_iou": train_iou, "train_dice": train_dice}

    def validation_step(self, batch, batch_idx):
        """Validate one batch at its original size.

        Pads H/W up to multiples of 16 (the UNet downsampling requirement),
        runs the model, crops the prediction back, then computes loss and
        metrics on the original size. Also dumps image / ground-truth /
        prediction PNGs for this batch.
        NOTE(review): the PNGs are named by batch_idx only, so every epoch
        overwrites the previous epoch's plots — confirm that is intended.
        """
        img, mask = batch["image"], batch["mask"]  # image --> torch.float(), mask --> torch.Long
        img = img.float()  # B Channels H W
        mask = mask.long()  # B Channels H W
        ###############################################
        save_plots_image(img, batch_idx)  # save the images
        save_plots_mask(mask, batch_idx)  # save the masks
        ###############################################
        # pad the image so the UNet's repeated 2x downsampling divides evenly
        padded_image, ind = Pad_images(img)
        padded_image = padded_image.cuda()
        # padded image passed through the model
        out = self(padded_image).cuda()  # B Classes H W
        # unpad the image (crop the prediction back to the input size)
        unpadded_prediction = UnPad_imges(out, ind, img.shape)
        unpadded_prediction = unpadded_prediction.cuda()
        ###############################################
        save_plots_pred(unpadded_prediction, batch_idx)  # save the predictions
        ###############################################
        # calculate loss
        loss = self.loss_function(unpadded_prediction, mask)
        # calculate softmax of the prediction
        soft_out = soft(unpadded_prediction)
        mask = mask.squeeze(dim=1)  # B H W
        """ Calculation of metrics using Torchmetrics"""
        # # calculate iou
        # iou_all = IOU_metric(soft_out, mask)
        # # val_iou = iou_all.mean()
        # iou_all = iou_all[iou_all != -1.]
        # if len(iou_all) == 0:
        #     val_iou = torch.tensor(0.0).cuda()
        # else:
        #     val_iou = iou_all.mean()
        # # calculate dice score
        # dice_all = f1_metric(soft_out, mask)
        # # val_dice = dice_all.mean()
        # dice_all_np = dice_all.cpu().numpy()
        # dice_all_np = dice_all_np[~np.isnan(dice_all_np)]
        # dice_all = (torch.from_numpy(dice_all_np).cuda())
        # if len(dice_all) == 0:
        #     val_dice = torch.tensor(0.0).cuda()
        # else:
        #     val_dice = dice_all.mean()
        """ Calculation of metrics using Torchmetrics functional"""
        # iou: -1 sentinel for absent classes, filtered before averaging
        iou_all = iou(soft_out, mask, absent_score=-1., num_classes=4, reduction='none', ignore_index=None)
        iou_all = iou_all[iou_all != -1.]
        if len(iou_all) == 0:
            val_iou = torch.tensor(0.0).cuda()
        else:
            val_iou = iou_all.mean()
        # dice score
        dice_all = dice_score(soft_out, mask, bg=True, no_fg_score=-1., reduction='none')
        dice_all = dice_all[dice_all != -1.]
        if len(dice_all) == 0:
            val_dice = torch.tensor(0.0).cuda()
        else:
            val_dice = dice_all.mean()
        # logger --> log the val_loss per step; epoch averages follow in validation_epoch_end
        self.log('val_loss', loss, on_step=True, on_epoch=False, prog_bar=False)
        return {'loss': loss, 'val_iou': val_iou, 'val_dice': val_dice}

    def training_epoch_end(self, train_step_outputs):
        """-----Calculate and logs the average train loss, IoU score and Dice Score-----"""
        # avg_train_loss is also what the ReduceLROnPlateau scheduler monitors
        avg_train_loss = torch.stack([x['loss'] for x in train_step_outputs]).mean()
        avg_train_iou = torch.stack([x['train_iou'] for x in train_step_outputs]).mean()
        avg_train_dice = torch.stack([x['train_dice'] for x in train_step_outputs]).mean()
        self.log('avg_train_loss', avg_train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('avg_train_iou', avg_train_iou, on_step=False, on_epoch=True, prog_bar=True)
        self.log('avg_train_dice', avg_train_dice, on_step=False, on_epoch=True, prog_bar=True)

    def validation_epoch_end(self, val_step_outputs):
        """-----Calculate and logs the average validation loss, IoU score and Dice Score-----"""
        # avg_val_iou drives checkpointing; avg_val_loss drives early stopping
        avg_val_loss = torch.stack([x['loss'] for x in val_step_outputs]).mean()
        avg_val_iou = torch.stack([x['val_iou'] for x in val_step_outputs]).mean()
        avg_val_dice = torch.stack([x['val_dice'] for x in val_step_outputs]).mean()
        self.log('avg_val_loss', avg_val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('avg_val_iou', avg_val_iou, on_step=False, on_epoch=True, prog_bar=True)
        self.log('avg_val_dice', avg_val_dice, on_step=False, on_epoch=True, prog_bar=True)
        return {'avg_val_loss': avg_val_loss, 'avg_val_iou': avg_val_iou, 'avg_val_dice': avg_val_dice}

    def configure_optimizers(self):
        """-----Optimizers and LR Schedulers-----"""
        # only Adam is supported; learning_rate / LR_decay_rate come from the CLI
        if optim_choice == 'adam':
            optim = torch.optim.Adam(self.net.parameters(), lr=learning_rate, eps=1e-8, weight_decay=1e-5, amsgrad=True)
        else:
            raise ValueError("Wrong optimizer!")
        if scheduler_choice == 'plateau':
            scheduler = ReduceLROnPlateau(optim, mode='min', factor=LR_decay_rate, patience=20)
        elif scheduler_choice == 'step':
            scheduler = StepLR(optim, step_size=10, gamma=LR_decay_rate, last_epoch=-1)
        else:
            raise ValueError("Wrong scheduler!")
        # 'monitor' is required by ReduceLROnPlateau; it watches the epoch-level
        # train loss logged in training_epoch_end
        return {
            "optimizer": optim,
            "lr_scheduler": scheduler,
            'monitor': 'avg_train_loss'
        }
+
+
def run_training():
    """--------------------------------------5 fold Cross Validation--------------------------------------"""
    # KFold only needs the index space, so split over arange(len(dataset));
    # each fold trains a fresh Train2D, saves the best checkpoint, exports the
    # validation PNGs and renders loss/IoU/Dice curves from the TB event file.
    for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(concatenated_dataset)))):
        print(len(train_idx), len(val_idx))
        print("--------------------------", "Fold", fold + 1, "--------------------------")
        print("Train Batch Size:", batch_size_train,
              "Val Batch Size:", batch_size_val,
              "Learning Rate:", learning_rate,
              "Max epochs:", max_epochs)

        """-------------------Train the model for "max_epochs" for each fold-------------------"""
        # training dataset (augmented, restricted to this fold's train indices)
        training_data = DataLoader(train_loader_ACDC(transform=train_transform, train_index=train_idx),
                                   batch_size=batch_size_train,
                                   shuffle=True, num_workers=2)
        # validation dataset (no augmentation, batch_size_val is 1 by default)
        validation_data = DataLoader(val_loader_ACDC(transform=val_transform, val_index=val_idx),
                                     batch_size=batch_size_val, shuffle=False, num_workers=2)
        # init the model (fresh LightningModule per fold)
        model = Train2D()
        # name of the model: model / dropout / date / fold, used for all logs
        name = str(model_choice) + "_" + str(drop_rate) + "_" + str(datetime.date.today()) + "_Fold_" + str(fold + 1)
        #  Checkpoint callback (best avg_val_iou) and Early Stopping (avg_val_loss)
        checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=checkpoint_path,
                                                           save_top_k=1,
                                                           save_last=True,
                                                           verbose=True,
                                                           monitor='avg_val_iou',
                                                           mode='max',
                                                           filename=name + "_" + '{epoch}-{avg_val_iou:.4f}',
                                                           )
        # NOTE(review): with the defaults, patience (400) equals max_epochs, so
        # early stopping effectively never triggers — confirm that is intended.
        early_stop_callback = pl.callbacks.EarlyStopping(monitor='avg_val_loss',
                                                         min_delta=0.00,
                                                         patience=patience,
                                                         verbose=False,
                                                         mode='min')
        # Tensorboard logger --> tensorboard --logdir=tb_logs
        tensorboard_logger = TensorBoardLogger(tb_path, name=name)
        # CSV logger
        csv_logger = CSVLogger(csv_path, name=name)
        # Trainer for training
        # NOTE(review): gpus=1 and train_dataloader= are the PL 1.x API; the CLI
        # --gpus value (`dev`) is printed but not passed here — confirm.
        trainer = Trainer(max_epochs=max_epochs, callbacks=[early_stop_callback, checkpoint_callback],
                          gpus=1, logger=[tensorboard_logger, csv_logger], fast_dev_run=False, log_every_n_steps=2)
        # Training the model
        trainer.fit(model, train_dataloader=training_data, val_dataloaders=validation_data)

        # Save the best model (reload the top avg_val_iou checkpoint and pickle
        # the whole module with torch.save)
        best_model_path = checkpoint_callback.best_model_path
        model = model.load_from_checkpoint(best_model_path)
        model.eval().cuda()
        fname = str(model_choice) + "_Best_" + str(drop_rate) + "_Fold_" + str(fold + 1)
        if not os.path.exists(r'../unet/best_models/'):
            os.makedirs(r'../unet/best_models/')
        torch.save(model, str(Path('../unet/best_models/', fname + '.pt')))

        # Folders to save the validation images for each fold
        if not os.path.exists(os.path.join(r'../unet/', name, f"{fold + 1}_Fold")):
            os.makedirs(os.path.join(r'../unet/', name, f"{fold + 1}_Fold"))
        val_images_path = os.path.join(r'../unet/', name, f"{fold + 1}_Fold")

        #  Move the validated images to the respective folders
        for filename in glob.glob(os.path.join(val_path, '*.*')):
            shutil.move(filename, val_images_path)

        # Save plots --> Loss, IoU and Dice
        plot_out_path = str(Path(r"../unet/Plots/", name))
        if not os.path.exists(plot_out_path):
            os.makedirs(plot_out_path)

        # Read the scalar curves back out of this fold's TensorBoard event file
        event_acc = ea(str(Path(r"../unet/tb_logs/", name, "version_0")))
        event_acc.Reload()

        # Scalars() yields (wall_time, step, value) triples; zip(*...) transposes
        # them, and only the value tuple is kept
        _, _, training_loss = zip(*event_acc.Scalars('avg_train_loss'))
        _, _, validation_loss = zip(*event_acc.Scalars('avg_val_loss'))
        _, _, training_iou = zip(*event_acc.Scalars('avg_train_iou'))
        _, _, validation_iou = zip(*event_acc.Scalars('avg_val_iou'))
        _, _, training_dice = zip(*event_acc.Scalars('avg_train_dice'))
        _, _, validation_dice = zip(*event_acc.Scalars('avg_val_dice'))

        t_loss, v_loss, t_iou, v_iou, t_dice, v_dice = np.array(training_loss), np.array(validation_loss), \
                                                       np.array(training_iou), np.array(validation_iou), \
                                                       np.array(training_dice), np.array(validation_dice)
        # the val curves can be one entry longer (sanity-check pass), so trim
        # every series to the shortest length before plotting
        min_length = min(len(t_loss), len(v_loss), len(t_iou), len(v_iou), len(t_dice), len(v_dice))
        total_epochs = np.arange(1, min_length + 1)

        # Save the Loss, IoU and Dice plots
        plt.figure(1)
        plt.rcParams.update({'font.size': 15})
        plt.plot(total_epochs, t_loss[0:min_length], 'X-', label='Training Loss', linewidth=2.0)
        plt.plot(total_epochs, v_loss[0:min_length], 'o-', label='Validation Loss', linewidth=2.0)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc='upper right')
        plt.minorticks_on()
        plt.grid(which='minor', linestyle='--')
        plt.savefig(str(Path(plot_out_path, 'Loss_Plot.png')), bbox_inches='tight', format='png', dpi=300)
        plt.close()  # Always plt.close() to save memory

        plt.figure(2)
        plt.rcParams.update({'font.size': 15})
        plt.plot(total_epochs, t_iou[0:min_length], 'X-', label='Training IOU', linewidth=2.0)
        plt.plot(total_epochs, v_iou[0:min_length], 'o-', label='Validation IOU', linewidth=2.0)
        plt.xlabel('Epoch')
        plt.ylabel('IOU')
        plt.legend(loc='lower right')
        plt.minorticks_on()
        plt.grid(which='minor', linestyle='--')
        plt.savefig(str(Path(plot_out_path, 'Iou_Plot.png')), bbox_inches='tight', format='png', dpi=300)
        plt.close()  # Always plt.close() to save memory

        plt.figure(3)
        plt.rcParams.update({'font.size': 15})
        plt.plot(total_epochs, t_dice[0:min_length], 'X-', label='Training Dice', linewidth=2.0)
        plt.plot(total_epochs, v_dice[0:min_length], 'o-', label='Validation Dice', linewidth=2.0)
        plt.xlabel('Epoch')
        plt.ylabel('Dice Score')
        plt.legend(loc='lower right')
        plt.minorticks_on()
        plt.grid(which='minor', linestyle='--')
        plt.savefig(str(Path(plot_out_path, 'Dice_Plot.png')), bbox_inches='tight', format='png', dpi=300)
        plt.close()  # Always plt.close() to save memory

        # reset model parameters after each fold
        # NOTE(review): this resets the *just-loaded best* model, and a fresh
        # Train2D is constructed at the top of the next iteration anyway, so
        # this call looks redundant — confirm intent before removing.
        model.apply(reset_weights)
+
+
+if __name__ == "__main__":
+    run_training()