import time

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from IPython.display import clear_output
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from segmentation_models_pytorch import Unet

from dataset_dataloader import get_dataloader
from loss_metric import Meter, BCEDiceLoss

class Trainer:
"""
Factory for training proccess.
Args:
display_plot: if True - plot train history after each epoch.
net: neural network for mask prediction.
criterion: factory for calculating objective loss.
optimizer: optimizer for weights updating.
phases: list with train and validation phases.
dataloaders: dict with data loaders for train and val phases.
imgs_dir: path to folder with images.
masks_dir: path to folder with imasks.
path_to_csv: path to csv file.
meter: factory for storing and updating metrics.
batch_size: data batch size for one step weights updating.
num_epochs: num weights updation for all data.
accumulation_steps: the number of steps after which the optimization step can be taken
(https://www.kaggle.com/c/understanding_cloud_organization/discussion/105614).
lr: learning rate for optimizer.
scheduler: scheduler for control learning rate.
losses: dict for storing lists with losses for each phase.
jaccard_scores: dict for storing lists with jaccard scores for each phase.
dice_scores: dict for storing lists with dice scores for each phase.
"""
def __init__(self,
net: nn.Module,
criterion: nn.Module,
lr: float,
accumulation_steps: int,
batch_size: int,
num_epochs: int,
imgs_dir: str,
masks_dir: str,
path_to_csv: str,
display_plot: bool = True
):
"""Initialization."""
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device:", self.device)
self.display_plot = display_plot
self.net = net
self.net = self.net.to(self.device)
self.criterion = criterion
self.optimizer = Adam(self.net.parameters(), lr=lr)
self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
patience=3, verbose=True)
        # Convert the target effective batch size (in samples) into a number
        # of batches; guard against zero when accumulation_steps < batch_size.
        self.accumulation_steps = max(1, accumulation_steps // batch_size)
self.phases = ["train", "val"]
self.num_epochs = num_epochs
        self.dataloaders = {
            phase: get_dataloader(
                imgs_dir=imgs_dir,
                masks_dir=masks_dir,
                path_to_csv=path_to_csv,
                phase=phase,
                batch_size=batch_size,
                num_workers=6
            )
            for phase in self.phases
        }
self.best_loss = float("inf")
self.losses = {phase: [] for phase in self.phases}
self.dice_scores = {phase: [] for phase in self.phases}
self.jaccard_scores = {phase: [] for phase in self.phases}
def _compute_loss_and_outputs(self,
images: torch.Tensor,
targets: torch.Tensor):
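        # Assumption: the criterion consumes raw logits (BCEDiceLoss is
        # expected to apply the sigmoid internally, as nn.BCEWithLogitsLoss
        # does), so no activation is applied here.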
images = images.to(self.device)
targets = targets.to(self.device)
logits = self.net(images)
loss = self.criterion(logits, targets)
return loss, logits
def _do_epoch(self, epoch: int, phase: str):
print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")
        if phase == "train":
            self.net.train()
        else:
            self.net.eval()
meter = Meter()
dataloader = self.dataloaders[phase]
total_batches = len(dataloader)
running_loss = 0.0
self.optimizer.zero_grad()
for itr, (images, targets) in enumerate(dataloader):
loss, logits = self._compute_loss_and_outputs(images, targets)
            # Scale the loss so gradients accumulated over several batches
            # average rather than sum.
            loss = loss / self.accumulation_steps
if phase == "train":
loss.backward()
if (itr + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
running_loss += loss.item()
meter.update(logits.detach().cpu(),
targets.detach().cpu()
)
        # Undo the accumulation scaling to report the mean per-batch loss.
        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
epoch_dice, epoch_iou = meter.get_metrics()
self.losses[phase].append(epoch_loss)
self.dice_scores[phase].append(epoch_dice)
self.jaccard_scores[phase].append(epoch_iou)
return epoch_loss
def train(self):
for epoch in range(self.num_epochs):
self._do_epoch(epoch, "train")
with torch.no_grad():
val_loss = self._do_epoch(epoch, "val")
self.scheduler.step(val_loss)
if self.display_plot:
self._plot_train_history()
if val_loss < self.best_loss:
print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
self.best_loss = val_loss
torch.save(self.net.state_dict(), "best_model.pth")
print()
self._save_train_history()
    def _plot_train_history(self):
        data = [self.losses, self.dice_scores, self.jaccard_scores]
        colors = ["deepskyblue", "crimson"]
        labels = [
            f"train loss {self.losses['train'][-1]:.4f}\n"
            f"val loss {self.losses['val'][-1]:.4f}",

            f"train dice score {self.dice_scores['train'][-1]:.4f}\n"
            f"val dice score {self.dice_scores['val'][-1]:.4f}",

            f"train jaccard score {self.jaccard_scores['train'][-1]:.4f}\n"
            f"val jaccard score {self.jaccard_scores['val'][-1]:.4f}",
        ]
        clear_output(True)
        # matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*".
        style = ("seaborn-v0_8-dark-palette"
                 if "seaborn-v0_8-dark-palette" in plt.style.available
                 else "seaborn-dark-palette")
        with plt.style.context(style):
            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
            for i, ax in enumerate(axes):
                ax.plot(data[i]['val'], c=colors[0], label="val")
                ax.plot(data[i]['train'], c=colors[-1], label="train")
                ax.set_title(labels[i])
                ax.legend(loc="upper right")
            plt.tight_layout()
            plt.show()
    def load_pretrained_model(self, state_path: str):
        # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
        self.net.load_state_dict(torch.load(state_path,
                                            map_location=self.device))
        print("Pretrained model loaded")
    def _save_train_history(self):
        """Write the model weights and training logs to files."""
        torch.save(self.net.state_dict(), "last_epoch_model.pth")
        metrics = [self.losses, self.dice_scores, self.jaccard_scores]
        suffixes = ["_loss", "_dice", "_jaccard"]
        # Columns are named "<phase><suffix>", e.g. "train_loss", "val_dice".
        logs = {
            phase + suffix: metric[phase]
            for metric, suffix in zip(metrics, suffixes)
            for phase in metric
        }
        pd.DataFrame(logs).to_csv("train_log.csv", index=False)
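

if __name__ == "__main__":
    # Minimal usage sketch, assuming the smp Unet and the BCEDiceLoss /
    # get_dataloader interfaces imported above; the encoder choice and the
    # data paths below are hypothetical placeholders, not part of the
    # original pipeline.
    model = Unet(
        encoder_name="resnet34",       # any encoder supported by smp
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,                     # single-channel lung mask
    )
    trainer = Trainer(
        net=model,
        criterion=BCEDiceLoss(),       # assumed to need no constructor args
        lr=5e-4,
        accumulation_steps=32,         # effective batch of 32 samples
        batch_size=8,
        num_epochs=20,
        imgs_dir="data/images",        # hypothetical path
        masks_dir="data/masks",        # hypothetical path
        path_to_csv="data/train.csv",  # hypothetical path
    )
    trainer.train()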