--- a
+++ b/src/train/trainer.py
@@ -0,0 +1,270 @@
+import torch
+from PIL import Image
+from torchvision import transforms as T
+from tqdm import tqdm
+
+from src.dataset.utils.visualization import plot_batch
+from src.models.io_model import save_checkpoint, save_model
+from src.metrics.training_metrics import AverageMeter
+from src.logging_conf import logger
+
+
+class TrainerArgs:
+    """Settings bundle read by Trainer: epoch count, device, output path for checkpoints and loss-mode name."""
+
+    def __init__(self, n_epochs=50, device="cpu", output_path="", loss="dice"):
+        # Stored verbatim; Trainer reads them as ``self.args.<name>``.
+        self.n_epochs, self.device, self.output_path, self.loss = n_epochs, device, output_path, loss
+
+
+class Trainer:
+    """Train and validate a segmentation model, logging metrics through ``writer``.
+
+    ``args.loss`` selects how the criterion output is unpacked (same in training and validation):
+
+    * ``"dice"``      -> (dice_loss, mean_dice, per_channel_dice)
+    * ``"both_dice"`` -> (total_loss, dice_loss, mean_dice, dice_loss_reg, subregion_loss)
+    * ``"gdl"``       -> (dice_loss, mean_dice)
+    * anything else   -> (combined_loss, dice_loss, ce_loss, mean_dice, subregion_loss)
+
+    The first element is the tensor back-propagated during training.
+    """
+
+    # Scalar/image tag names per phase, kept verbatim so existing dashboards still match.
+    _TAGS = {
+        "train": {
+            "prefix": "Training",
+            "image": "",
+            "region_total": "Train combined Region-Dice Loss",
+            "region": "Train region dice loss",
+            "combined": "Train combined CE-Dice Loss",
+            "ce": "Train Cross Entropy Loss",
+        },
+        "val": {
+            "prefix": "Validation",
+            "image": "Val ",
+            "region_total": "Validation combined Region-Dice Loss",
+            "region": "Validation region dice loss",
+            "combined": "Validation Combined CE-Dice Loss",
+            "ce": "Validation Cross Entropy Loss",
+        },
+    }
+
+    def __init__(self, args, model, optimizer, criterion, start_epoch, train_loader, val_loader, lr_scheduler, writer):
+        """Store the training components.
+
+        Args:
+            args: settings object exposing ``n_epochs``, ``device``, ``output_path`` and ``loss``.
+            model: model whose forward returns a 2-tuple; only the first element is used.
+            optimizer: optimizer stepped once per training batch.
+            criterion: loss callable whose return layout depends on ``args.loss`` (see class docstring).
+            start_epoch: index of the first epoch to run (allows resuming).
+            train_loader: iterable of ``(data_batch, labels_batch)`` that supports ``len()``.
+            val_loader: same contract as ``train_loader``, used for validation.
+            lr_scheduler: optional; when truthy its ``step`` receives the monitored validation loss.
+            writer: object exposing ``add_scalar``/``add_image``/``add_graph``
+                (presumably a TensorBoard ``SummaryWriter``).
+        """
+        self.model = model
+        self.optimizer = optimizer
+        self.criterion = criterion
+
+        self.train_data_loader = train_loader
+        self.number_train_data = len(self.train_data_loader)
+
+        self.valid_data_loader = val_loader
+        self.number_val_data = len(self.valid_data_loader)
+
+        self.lr_scheduler = lr_scheduler
+        self.writer = writer
+
+        self.start_epoch = start_epoch
+        self.args = args
+
+    def start(self, best_loss=1000):
+        """Run epochs ``start_epoch`` .. ``args.n_epochs - 1``, checkpointing after each, then save the model.
+
+        Args:
+            best_loss: monitored validation loss a checkpoint must beat to be flagged as best.
+        """
+        val_dice_score = 0
+
+        for epoch in range(self.start_epoch, self.args.n_epochs):
+            train_dice_loss, train_dice_score, train_combined_loss, train_ce_loss = self.train_epoch(epoch)
+            val_dice_loss, val_dice_score, val_combined_loss, val_ce_loss = self.val_epoch(epoch)
+
+            # "combined" mode monitors the combined loss; every other mode monitors the dice loss.
+            val_loss = val_combined_loss if self.args.loss == "combined" else val_dice_loss
+            if self.lr_scheduler:
+                self.lr_scheduler.step(val_loss)
+
+            self._epoch_summary(epoch, train_dice_loss, val_dice_loss, train_dice_score, val_dice_score,
+                                train_combined_loss, train_ce_loss, val_combined_loss, val_ce_loss)
+
+            is_best = bool(val_loss < best_loss)
+            best_loss = val_loss if is_best else best_loss
+            # 'val_loss' stores the best monitored loss so far, not necessarily this epoch's value.
+            save_checkpoint({
+                'epoch': epoch,
+                'model_state_dict': self.model.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict(),
+                'val_loss': best_loss,
+                'val_dice_score': val_dice_score
+            }, is_best, self.args.output_path)
+
+        save_model({
+            'epoch': self.args.n_epochs + 1,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'val_loss': best_loss,
+            'val_dice_score': val_dice_score
+        }, self.args.output_path)
+
+    def train_epoch(self, epoch):
+        """Optimise the model over the whole training loader once.
+
+        Returns:
+            ``(dice_loss, dice_score, combined_loss, ce_loss)`` epoch averages; the last two
+            are 0 unless ``args.loss == "combined"``.
+        """
+        self.model.train()
+        meters = self._new_meters()
+
+        for i, (data_batch, labels_batch) in enumerate(tqdm(self.train_data_loader, desc="Training epoch")):
+            self.optimizer.zero_grad()
+            inputs = data_batch.float().to(self.args.device)
+            targets = labels_batch.float().to(self.args.device)
+
+            # Log the model graph once per run rather than every epoch; tracing re-runs the forward pass.
+            if epoch == self.start_epoch and i == 0:
+                self.writer.add_graph(self.model, inputs)
+
+            predictions, _ = self.model(inputs)
+            step = epoch * self.number_train_data + i
+            objective = self._evaluate_batch(predictions, targets, "train", step, data_batch.size(0), meters)
+            objective.backward()
+            self.optimizer.step()
+
+            # Rendering patch plots is costly, so images are logged for the first batch of each epoch only.
+            if i == 0:
+                self._log_images(data_batch, labels_batch, predictions, "train", epoch)
+
+        return self._epoch_averages(meters)
+
+    def val_epoch(self, epoch):
+        """Evaluate the model over the whole validation loader without gradient tracking.
+
+        Returns:
+            Same layout as ``train_epoch``.
+        """
+        self.model.eval()
+        meters = self._new_meters()
+
+        for i, (data_batch, labels_batch) in enumerate(tqdm(self.valid_data_loader, desc="Validation epoch")):
+            inputs = data_batch.float().to(self.args.device)
+            targets = labels_batch.float().to(self.args.device)
+
+            with torch.no_grad():
+                outputs, _ = self.model(inputs)
+                # Every validation scalar uses the validation loader length for its step index.
+                step = epoch * self.number_val_data + i
+                self._evaluate_batch(outputs, targets, "val", step, data_batch.size(0), meters)
+
+            if i == 0:
+                self._log_images(data_batch, labels_batch, outputs, "val", epoch)
+
+        return self._epoch_averages(meters)
+
+    def _evaluate_batch(self, predictions, targets, phase, step, batch_size, meters):
+        """Apply the criterion for ``args.loss``, log the batch scalars and update ``meters``.
+
+        Args:
+            predictions: first model output for the batch.
+            targets: ground truth on the same device as ``predictions``.
+            phase: ``"train"`` or ``"val"``; selects tag names from ``_TAGS``.
+            step: global step for every scalar logged here.
+            batch_size: weight of this batch in the running averages.
+            meters: dict from ``_new_meters``, updated in place.
+
+        Returns:
+            The criterion's first output (the tensor to back-propagate when training).
+        """
+        tags = self._TAGS[phase]
+        per_channel_dice, subregion_loss = None, None
+
+        if self.args.loss == "dice":
+            objective, mean_dice, per_channel_dice = self.criterion(predictions, targets)
+            dice_loss = objective
+        elif self.args.loss == "both_dice":
+            objective, dice_loss, mean_dice, dice_loss_reg, subregion_loss = self.criterion(predictions, targets)
+            self.writer.add_scalar(tags["region_total"], objective.detach().item(), step)
+            self.writer.add_scalar(tags["region"], dice_loss_reg.detach().item(), step)
+        elif self.args.loss == "gdl":
+            objective, mean_dice = self.criterion(predictions, targets)
+            dice_loss = objective
+        else:
+            objective, dice_loss, ce_loss, mean_dice, subregion_loss = self.criterion(predictions, targets)
+            combined_value, ce_value = objective.detach().item(), ce_loss.detach().item()
+            meters["combined"].update(combined_value, batch_size)
+            meters["ce"].update(ce_value, batch_size)
+            self.writer.add_scalar(tags["combined"], combined_value, step)
+            self.writer.add_scalar(tags["ce"], ce_value, step)
+
+        # The "dice" criterion's third output is per-channel dice (NCR/ED/ET), logged alike in both phases.
+        self._log_components(tags["prefix"], ("NCR", "ED", "ET"), per_channel_dice, step)
+        self._log_components(tags["prefix"], ("WT", "TC", "ET"), subregion_loss, step)
+
+        dice_value, score_value = dice_loss.detach().item(), mean_dice.detach().item()
+        meters["dice"].update(dice_value, batch_size)
+        meters["score"].update(score_value, batch_size)
+        self.writer.add_scalar(f'{tags["prefix"]} Dice Loss', dice_value, step)
+        self.writer.add_scalar(f'{tags["prefix"]} Dice Score', score_value, step)
+        return objective
+
+    def _log_components(self, prefix, names, values, step):
+        """Log ``values[k]`` as '<prefix> Dice Loss <names[k]>'; no-op when ``values`` is None or empty."""
+        if values is None:
+            return
+        # zip() rather than a truthiness test, so a multi-element tensor is accepted as well as a list.
+        for name, value in zip(names, values):
+            self.writer.add_scalar(f'{prefix} Dice Loss {name}', value.detach().item(), step)
+
+    def _new_meters(self):
+        """Create running averages for dice loss, dice score, combined loss and CE loss."""
+        return {name: AverageMeter() for name in ("dice", "score", "combined", "ce")}
+
+    def _epoch_averages(self, meters):
+        """Return ``(dice_loss, dice_score, combined_loss, ce_loss)``; the last two are 0 unless "combined"."""
+        if self.args.loss == "combined":
+            return meters["dice"].avg(), meters["score"].avg(), meters["combined"].avg(), meters["ce"].avg()
+        return meters["dice"].avg(), meters["score"].avg(), 0, 0
+
+    def _log_images(self, data_batch, labels_batch, predictions, phase, epoch):
+        """Log input, ground-truth and arg-max prediction patches of one batch at step ``epoch``."""
+        prefix = self._TAGS[phase]["image"]
+        self._add_image(data_batch, False, f"{prefix}Modality patch", epoch)
+        self._add_image(labels_batch, True, f"{prefix}Segmentation ground truth patch", epoch)
+        # Arg-max index along dim 1 (presumably the class channel), moved to the CPU before plotting.
+        self._add_image(predictions.max(1)[1].cpu(), True, f"{prefix}Segmentation prediction patch", epoch)
+
+    def _add_image(self, batch, seg=False, title="", global_step=None):
+        """Render ``batch`` with ``plot_batch`` (slice 16) and log the picture under ``title``."""
+        plot_buf = plot_batch(batch, seg=seg, slice=16, batch_size=len(batch))
+        with Image.open(plot_buf) as im:
+            image = T.ToTensor()(im)
+        self.writer.add_image(title, image, global_step=global_step)
+
+    def _epoch_summary(self, epoch, train_loss, val_loss, train_dice_score, val_dice_score, train_combined_loss,
+                       train_ce_loss, val_combined_loss, val_ce_loss):
+        """Log the epoch averages; CE/combined lines appear only in "combined" mode, the only non-zero case."""
+        if self.args.loss != "combined":
+            logger.info(f'epoch: {epoch}\n '
+                        f'** Dice Loss **  : train_loss: {train_loss:.2f} | val_loss {val_loss:.2f} \n'
+                        f'** Dice Score ** : train_dice_score {train_dice_score:.2f} | val_dice_score {val_dice_score:.2f}')
+        else:
+            logger.info(f'epoch: {epoch}\n'
+                        f'** Combined Loss **  : train_loss: {train_combined_loss:.2f} | val_loss {val_combined_loss:.2f} \n'
+                        f'** CE Loss **        : train_loss {train_ce_loss:.2f} | val_loss {val_ce_loss:.2f}\n'
+                        f'** Dice Loss **      : train_loss: {train_loss:.2f} | val_loss {val_loss:.2f} \n'
+                        f'** Dice Score **     : train_dice_score {train_dice_score:.2f} | val_dice_score {val_dice_score:.2f}\n'
+                        )