--- /dev/null
+++ b/train.py
@@ -0,0 +1,202 @@
+import time
+
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+from dataset_dataloader import get_dataloader
+from loss_metric import Meter, BCEDiceLoss
+
+from segmentation_models_pytorch import Unet
+from IPython.display import clear_output
+
+
+class Trainer:
+    """
+    Encapsulates the training process.
+
+    Args:
+        net: neural network for mask prediction.
+        criterion: module that computes the objective loss.
+        lr: learning rate for the optimizer.
+        accumulation_steps: number of samples over which gradients are
+            accumulated before an optimization step is taken
+            (https://www.kaggle.com/c/understanding_cloud_organization/discussion/105614).
+        batch_size: data batch size for one weight-update step.
+        num_epochs: number of passes over the full dataset.
+        imgs_dir: path to the folder with images.
+        masks_dir: path to the folder with masks.
+        path_to_csv: path to the csv file.
+        display_plot: if True, plot the training history after each epoch.
+
+    Attributes:
+        optimizer: optimizer for weight updates.
+        scheduler: scheduler that controls the learning rate.
+        phases: list with the train and validation phases.
+        dataloaders: dict with data loaders for the train and val phases.
+        losses: dict storing lists of losses for each phase.
+        jaccard_scores: dict storing lists of jaccard scores for each phase.
+        dice_scores: dict storing lists of dice scores for each phase.
+    """
+    def __init__(self,
+                 net: nn.Module,
+                 criterion: nn.Module,
+                 lr: float,
+                 accumulation_steps: int,
+                 batch_size: int,
+                 num_epochs: int,
+                 imgs_dir: str,
+                 masks_dir: str,
+                 path_to_csv: str,
+                 display_plot: bool = True
+                 ):
+        """Initialization."""
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        print("device:", self.device)
+        self.display_plot = display_plot
+        self.net = net.to(self.device)
+        self.criterion = criterion
+        self.optimizer = Adam(self.net.parameters(), lr=lr)
+        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
+                                           patience=3, verbose=True)
+        # accumulation_steps is given in samples; convert it to the number
+        # of batches between optimizer steps.
+        self.accumulation_steps = accumulation_steps // batch_size
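+        # Worked example (assuming the sample-based semantics above): with
+        # accumulation_steps=32 and batch_size=8, gradients accumulate over
+        # 32 // 8 = 4 batches per optimizer step, emulating an effective
+        # batch size of 32.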
+        self.phases = ["train", "val"]
+        self.num_epochs = num_epochs
+
+        self.dataloaders = {
+            phase: get_dataloader(
+                imgs_dir=imgs_dir,
+                masks_dir=masks_dir,
+                path_to_csv=path_to_csv,
+                phase=phase,
+                batch_size=batch_size,  # was hard-coded to 8, ignoring the argument
+                num_workers=6
+            )
+            for phase in self.phases
+        }
+        self.best_loss = float("inf")
+        self.losses = {phase: [] for phase in self.phases}
+        self.dice_scores = {phase: [] for phase in self.phases}
+        self.jaccard_scores = {phase: [] for phase in self.phases}
+
+    def _compute_loss_and_outputs(self,
+                                  images: torch.Tensor,
+                                  targets: torch.Tensor):
+        images = images.to(self.device)
+        targets = targets.to(self.device)
+        logits = self.net(images)
+        loss = self.criterion(logits, targets)
+        return loss, logits
+
+    def _do_epoch(self, epoch: int, phase: str):
+        print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")
+
+        if phase == "train":
+            self.net.train()
+        else:
+            self.net.eval()
+        meter = Meter()
+        dataloader = self.dataloaders[phase]
+        total_batches = len(dataloader)
+        running_loss = 0.0
+        self.optimizer.zero_grad()
+        for itr, (images, targets) in enumerate(dataloader):
+            loss, logits = self._compute_loss_and_outputs(images, targets)
+            # Scale the loss so the accumulated gradient matches a large batch.
+            loss = loss / self.accumulation_steps
+            if phase == "train":
+                loss.backward()
+                if (itr + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+            running_loss += loss.item()
+            meter.update(logits.detach().cpu(),
+                         targets.detach().cpu())
+
+        # Undo the accumulation scaling to report the mean per-batch loss.
+        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
+        epoch_dice, epoch_iou = meter.get_metrics()
+
+        self.losses[phase].append(epoch_loss)
+        self.dice_scores[phase].append(epoch_dice)
+        self.jaccard_scores[phase].append(epoch_iou)
+
+        return epoch_loss
+
+    def train(self):
+        for epoch in range(self.num_epochs):
+            self._do_epoch(epoch, "train")
+            with torch.no_grad():
+                val_loss = self._do_epoch(epoch, "val")
+            self.scheduler.step(val_loss)
+            if self.display_plot:
+                self._plot_train_history()
+
+            if val_loss < self.best_loss:
+                print(f"\n{'#' * 20}\nSaved new checkpoint\n{'#' * 20}\n")
+                self.best_loss = val_loss
+                torch.save(self.net.state_dict(), "best_model.pth")
+            print()
+        self._save_train_history()
+
+    def _plot_train_history(self):
+        data = [self.losses, self.dice_scores, self.jaccard_scores]
+        colors = ['deepskyblue', 'crimson']
+        labels = [
+            f"train loss {self.losses['train'][-1]:.4f}\n"
+            f"val loss {self.losses['val'][-1]:.4f}",
+
+            f"train dice score {self.dice_scores['train'][-1]:.4f}\n"
+            f"val dice score {self.dice_scores['val'][-1]:.4f}",
+
+            f"train jaccard score {self.jaccard_scores['train'][-1]:.4f}\n"
+            f"val jaccard score {self.jaccard_scores['val'][-1]:.4f}",
+        ]
+
+        clear_output(True)
+        with plt.style.context("seaborn-dark-palette"):
+            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
+            for i, ax in enumerate(axes):
+                ax.plot(data[i]['val'], c=colors[0], label="val")
+                ax.plot(data[i]['train'], c=colors[-1], label="train")
+                ax.set_title(labels[i])
+                ax.legend(loc="upper right")
+
+            plt.tight_layout()
+            plt.show()
+
+    def load_pretrained_model(self, state_path: str):
+        """Load model weights from a checkpoint file."""
+        self.net.load_state_dict(torch.load(state_path,
+                                            map_location=self.device))
+        print("Pretrained model loaded")
+
+    def _save_train_history(self):
+        """Write model weights and training logs to files."""
+        torch.save(self.net.state_dict(), "last_epoch_model.pth")
+
+        metrics = {"loss": self.losses,
+                   "dice": self.dice_scores,
+                   "jaccard": self.jaccard_scores}
+        # One column per phase/metric pair, e.g. "train_loss", "val_dice".
+        logs = {f"{phase}_{name}": values[phase]
+                for name, values in metrics.items()
+                for phase in values}
+        pd.DataFrame(logs).to_csv("train_log.csv", index=False)
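+
+
+if __name__ == "__main__":
+    # Minimal usage sketch. The Unet arguments, the BCEDiceLoss constructor,
+    # and the data paths below are illustrative assumptions; adjust them to
+    # the actual dataset layout and the loss_metric implementation.
+    model = Unet(encoder_name="resnet34",
+                 encoder_weights="imagenet",
+                 in_channels=1,
+                 classes=1)
+    trainer = Trainer(net=model,
+                      criterion=BCEDiceLoss(),
+                      lr=5e-4,
+                      accumulation_steps=32,
+                      batch_size=8,
+                      num_epochs=30,
+                      imgs_dir="data/imgs",
+                      masks_dir="data/masks",
+                      path_to_csv="data/train.csv")
+    trainer.train()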