--- /dev/null
+++ b/src/train/trainer.py
@@ -0,0 +1,270 @@
+import torch
+from PIL import Image
+from torchvision import transforms as T
+from tqdm import tqdm
+
+from src.dataset.utils.visualization import plot_batch
+from src.models.io_model import save_checkpoint, save_model
+from src.metrics.training_metrics import AverageMeter
+from src.logging_conf import logger
+
+
+class TrainerArgs:
+    def __init__(self, n_epochs=50, device="cpu", output_path="", loss="dice"):
+        self.n_epochs = n_epochs
+        self.device = device
+        self.output_path = output_path
+        self.loss = loss
+
+
+class Trainer:
+
+    def __init__(self, args, model, optimizer, criterion, start_epoch, train_loader, val_loader, lr_scheduler, writer):
+        self.model = model
+        self.optimizer = optimizer
+        self.criterion = criterion
+
+        self.train_data_loader = train_loader
+        self.number_train_data = len(self.train_data_loader)
+
+        self.valid_data_loader = val_loader
+        self.number_val_data = len(self.valid_data_loader)
+
+        self.lr_scheduler = lr_scheduler
+        self.writer = writer
+
+        self.start_epoch = start_epoch
+        self.args = args
+
+    def start(self, best_loss=1000):
+        val_dice_score = 0
+
+        for epoch in range(self.start_epoch, self.args.n_epochs):
+
+            train_dice_loss, train_dice_score, train_combined_loss, train_ce_loss = self.train_epoch(epoch)
+            val_dice_loss, val_dice_score, val_combined_loss, val_ce_loss = self.val_epoch(epoch)
+
+            val_loss = val_combined_loss if self.args.loss == "combined" else val_dice_loss
+            if self.lr_scheduler:
+                self.lr_scheduler.step(val_loss)
+
+            self._epoch_summary(epoch, train_dice_loss, val_dice_loss, train_dice_score, val_dice_score,
+                                train_combined_loss, train_ce_loss, val_combined_loss, val_ce_loss)
+
+            is_best = bool(val_loss < best_loss)
+            best_loss = val_loss if is_best else best_loss
+            save_checkpoint({
+                'epoch': epoch,
+                'model_state_dict': self.model.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict(),
+                'val_loss': best_loss,
+                'val_dice_score': val_dice_score
+            }, is_best, self.args.output_path)
+
+            save_model({
+                'epoch': self.args.n_epochs + 1,
+                'model_state_dict': self.model.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict(),
+                'val_loss': best_loss,
+                'val_dice_score': val_dice_score
+            }, self.args.output_path)
+
+    def train_epoch(self, epoch):
+        self.model.train()
+        dice_loss_global, ce_loss_global, combined_loss_global = AverageMeter(), AverageMeter(), AverageMeter()
+        dice_score = AverageMeter()
+
+        i = 0
+        for data_batch, labels_batch in tqdm(self.train_data_loader, desc="Training epoch"):
+            # One optimization step over the current batch; written as a closure so it can
+            # capture the running batch index `i` and the current `epoch`.
+            def step(trainer):
+                trainer.optimizer.zero_grad()
+
+                inputs = data_batch.float().to(trainer.args.device)
+                targets = labels_batch.float().to(trainer.args.device)
+                inputs.requires_grad = True
+
+                if i == 0:
+                    # Log the model graph once per epoch, on the first batch.
+                    trainer.writer.add_graph(trainer.model, inputs)
+
+                predictions, _ = trainer.model(inputs)
+
+                if trainer.args.loss == "dice":
+                    dice_loss, mean_dice, per_channel_dice = trainer.criterion(predictions, targets)
+                    subregion_loss = []
+                    dice_loss.backward()
+                    trainer.optimizer.step()
+
+                    trainer.writer.add_scalar('Training Dice Loss NCR', per_channel_dice[0].detach().item(),
+                                              epoch * trainer.number_train_data + i)
+                    trainer.writer.add_scalar('Training Dice Loss ED', per_channel_dice[1].detach().item(),
+                                              epoch * trainer.number_train_data + i)
+                    trainer.writer.add_scalar('Training Dice Loss ET', per_channel_dice[2].detach().item(),
+                                              epoch * trainer.number_train_data + i)
+
+                elif trainer.args.loss == "both_dice":
+                    total_loss, dice_loss, mean_dice, dice_loss_reg, subregion_loss = trainer.criterion(predictions,
+                                                                                                        targets)
+                    total_loss.backward()
+                    trainer.optimizer.step()
+
+                    total_loss = total_loss.detach().item()
+                    dice_loss_reg = dice_loss_reg.detach().item()
+                    trainer.writer.add_scalar('Train combined Region-Dice Loss', total_loss,
+                                              epoch * trainer.number_train_data + i)
+                    trainer.writer.add_scalar('Train Region Dice Loss', dice_loss_reg,
+                                              epoch * trainer.number_train_data + i)
+
+                elif trainer.args.loss == "gdl":
+                    dice_loss, mean_dice = trainer.criterion(predictions, targets)
+                    subregion_loss = []
+                    dice_loss.backward()
+                    trainer.optimizer.step()
+
+                else:
+                    combined_loss, dice_loss, ce_loss, mean_dice, subregion_loss = trainer.criterion(predictions,
+                                                                                                     targets)
+                    combined_loss.backward()
+                    trainer.optimizer.step()
+
+                    combined_loss = combined_loss.detach().item()
+                    ce_loss = ce_loss.detach().item()
+
+                    combined_loss_global.update(combined_loss, data_batch.size(0))
+                    ce_loss_global.update(ce_loss, data_batch.size(0))
+
+                    trainer.writer.add_scalar('Train combined CE-Dice Loss', combined_loss,
+                                              epoch * trainer.number_train_data + i)
+                    trainer.writer.add_scalar('Train Cross Entropy Loss', ce_loss,
+                                              epoch * trainer.number_train_data + i)
+
+                dice_loss = dice_loss.detach().item()
+                mean_dice = mean_dice.detach().item()
+                dice_loss_global.update(dice_loss, data_batch.size(0))
+                dice_score.update(mean_dice, data_batch.size(0))
+
+                if subregion_loss:
+                    trainer.writer.add_scalar('Training Dice Loss WT', subregion_loss[0].detach().item(),
+                                              epoch * trainer.number_train_data + i)
+                    trainer.writer.add_scalar('Training Dice Loss TC', subregion_loss[1].detach().item(),
+                                              epoch * trainer.number_train_data + i)
+                    trainer.writer.add_scalar('Training Dice Loss ET', subregion_loss[2].detach().item(),
+                                              epoch * trainer.number_train_data + i)
+
+                trainer.writer.add_scalar('Training Dice Loss', dice_loss, epoch * trainer.number_train_data + i)
+                trainer.writer.add_scalar('Training Dice Score', mean_dice, epoch * trainer.number_train_data + i)
+
+                trainer._add_image(data_batch, False, "Modality patch")
+                trainer._add_image(labels_batch, True, "Segmentation ground truth patch")
+                trainer._add_image(predictions.max(1)[1], True, "Segmentation prediction patch")
+
+            step(self)
+
+            i += 1
+
+        if self.args.loss == "combined":
+            return dice_loss_global.avg(), dice_score.avg(), combined_loss_global.avg(), ce_loss_global.avg()
+        else:
+            return dice_loss_global.avg(), dice_score.avg(), 0, 0
+
+    def _add_image(self, batch, seg=False, title=""):
+        plot_buf = plot_batch(batch, seg=seg, slice=16, batch_size=len(batch))
+        im = Image.open(plot_buf)
+        image = T.ToTensor()(im)
+        self.writer.add_image(title, image)
+
+    def val_epoch(self, epoch):
+        self.model.eval()
+        losses, ce_loss_global, combined_loss_global = AverageMeter(), AverageMeter(), AverageMeter()
+        dice_score = AverageMeter()
+
+        i = 0
+        for data_batch, labels_batch in tqdm(self.valid_data_loader, desc="Validation epoch"):
+            def step(trainer):
+                inputs = data_batch.float().to(trainer.args.device)
+                targets = labels_batch.float().to(trainer.args.device)
+
+                with torch.no_grad():
+                    outputs, _ = trainer.model(inputs)
+
+                if trainer.args.loss == "dice":
+                    dice_loss, mean_dice, subregion_loss = trainer.criterion(outputs, targets)
+
+                elif trainer.args.loss == "gdl":
+                    dice_loss, mean_dice = trainer.criterion(outputs, targets)
+                    subregion_loss = []
+
+                elif trainer.args.loss == "both_dice":
+                    total_loss, dice_loss, mean_dice, dice_loss_reg, subregion_loss = trainer.criterion(outputs,
+                                                                                                        targets)
+                    total_loss = total_loss.detach().item()
+                    dice_loss_reg = dice_loss_reg.detach().item()
+
+                    trainer.writer.add_scalar('Validation combined Region-Dice Loss', total_loss,
+                                              epoch * trainer.number_val_data + i)
+                    trainer.writer.add_scalar('Validation Region Dice Loss', dice_loss_reg,
+                                              epoch * trainer.number_val_data + i)
+
+                else:
+                    combined_loss, dice_loss, ce_loss, mean_dice, subregion_loss = trainer.criterion(outputs,
+                                                                                                     targets)
+                    combined_loss = combined_loss.detach().item()
+                    ce_loss = ce_loss.detach().item()
+                    combined_loss_global.update(combined_loss, data_batch.size(0))
+                    ce_loss_global.update(ce_loss, data_batch.size(0))
+
+                    trainer.writer.add_scalar('Validation Combined CE-Dice Loss', combined_loss,
+                                              epoch * trainer.number_val_data + i)
+                    trainer.writer.add_scalar('Validation Cross Entropy Loss', ce_loss,
+                                              epoch * trainer.number_val_data + i)
+
+                dice_loss = dice_loss.detach().item()
+                mean_dice = mean_dice.detach().item()
+                losses.update(dice_loss, data_batch.size(0))
+                dice_score.update(mean_dice, data_batch.size(0))
+
+                if subregion_loss:
+                    trainer.writer.add_scalar('Validation Dice Loss WT', subregion_loss[0].detach().item(),
+                                              epoch * trainer.number_val_data + i)
+                    trainer.writer.add_scalar('Validation Dice Loss TC', subregion_loss[1].detach().item(),
+                                              epoch * trainer.number_val_data + i)
+                    trainer.writer.add_scalar('Validation Dice Loss ET', subregion_loss[2].detach().item(),
+                                              epoch * trainer.number_val_data + i)
+
+                trainer.writer.add_scalar('Validation Dice Loss', dice_loss, epoch * trainer.number_val_data + i)
+                trainer.writer.add_scalar('Validation Dice Score', mean_dice, epoch * trainer.number_val_data + i)
+
+                trainer._add_image(data_batch, False, "Val Modality patch")
+                trainer._add_image(labels_batch, True, "Val Segmentation ground truth patch")
+                trainer._add_image(outputs.max(1)[1], True, "Val Segmentation prediction patch")
+
+            step(self)
+
+            i += 1
+
+        if self.args.loss == "combined":
+            return losses.avg(), dice_score.avg(), combined_loss_global.avg(), ce_loss_global.avg()
+        else:
+            return losses.avg(), dice_score.avg(), 0, 0
+
+    def _epoch_summary(self, epoch, train_loss, val_loss, train_dice_score, val_dice_score, train_combined_loss,
+                       train_ce_loss, val_combined_loss, val_ce_loss):
+        if self.args.loss in ("dice", "both_dice", "gdl"):
+            logger.info(f'epoch: {epoch}\n'
+                        f'** Dice Loss ** : train_loss: {train_loss:.2f} | val_loss {val_loss:.2f}\n'
+                        f'** Dice Score ** : train_dice_score {train_dice_score:.2f} | val_dice_score {val_dice_score:.2f}')
+        else:
+            logger.info(f'epoch: {epoch}\n'
+                        f'** Combined Loss ** : train_loss: {train_combined_loss:.2f} | val_loss {val_combined_loss:.2f}\n'
+                        f'** CE Loss ** : train_loss {train_ce_loss:.2f} | val_loss {val_ce_loss:.2f}\n'
+                        f'** Dice Loss ** : train_loss: {train_loss:.2f} | val_loss {val_loss:.2f}\n'
+                        f'** Dice Score ** : train_dice_score {train_dice_score:.2f} | val_dice_score {val_dice_score:.2f}\n')
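
For reviewers, a minimal sketch of how this Trainer is expected to be wired up. Nothing below is part of this diff: `SegNet3D`, `DiceLoss`, `train_loader` and `val_loader` are placeholders standing in for the repo's model, criterion and data pipeline; only `TrainerArgs`/`Trainer` and the standard PyTorch / TensorBoard objects are real.

```python
# Wiring sketch only -- SegNet3D, DiceLoss and the two DataLoaders are
# placeholders, not names defined in this repository.
import torch
from torch.utils.tensorboard import SummaryWriter

from src.train.trainer import Trainer, TrainerArgs

device = "cuda" if torch.cuda.is_available() else "cpu"
args = TrainerArgs(n_epochs=50, device=device, output_path="results/exp1", loss="dice")

model = SegNet3D().to(device)   # placeholder: forward() must return (predictions, extra)
criterion = DiceLoss()          # placeholder: for loss="dice" must return (loss, mean_dice, per_channel_dice)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Trainer calls lr_scheduler.step(val_loss), so a metric-driven scheduler fits here.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=5)
writer = SummaryWriter(log_dir=args.output_path)

trainer = Trainer(args, model, optimizer, criterion,
                  start_epoch=0,
                  train_loader=train_loader,   # placeholder DataLoaders yielding (patch, label) batches
                  val_loader=val_loader,
                  lr_scheduler=scheduler,
                  writer=writer)
trainer.start()
```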