train.py

import os
import time

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from dataset_dataloader import LungsDataset, get_augmentations, get_dataloader
from loss_metric import Meter, BCEDiceLoss
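# NOTE: dataset_dataloader and loss_metric are local project modules; they are
# assumed to provide the dataset/dataloader pipeline and the combined
# BCE + Dice loss with Dice/Jaccard metric tracking used below.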

from segmentation_models_pytorch.unet import Unet
from IPython.display import clear_output


class Trainer:
    """
    Factory for the training process. A fresh Meter is created each epoch
    to store and update the Dice and Jaccard metrics.

    Args:
        net: neural network for mask prediction.
        criterion: loss function for computing the training objective.
        lr: learning rate for the optimizer.
        accumulation_steps: target effective batch size for gradient
            accumulation; the optimizer steps once every
            accumulation_steps // batch_size batches
            (https://www.kaggle.com/c/understanding_cloud_organization/discussion/105614).
        batch_size: number of samples per batch for one weight-update step.
        num_epochs: number of full passes over the training data.
        imgs_dir: path to the folder with images.
        masks_dir: path to the folder with masks.
        path_to_csv: path to the csv file describing the data.
        display_plot: if True, plot the training history after each epoch.

    Attributes:
        optimizer: optimizer for updating the weights.
        scheduler: scheduler for controlling the learning rate.
        phases: list with the train and validation phases.
        dataloaders: dict with data loaders for the train and val phases.
        losses: dict storing a list of losses for each phase.
        dice_scores: dict storing a list of Dice scores for each phase.
        jaccard_scores: dict storing a list of Jaccard scores for each phase.
    """
    def __init__(self,
                 net: nn.Module,
                 criterion: nn.Module,
                 lr: float,
                 accumulation_steps: int,
                 batch_size: int,
                 num_epochs: int,
                 imgs_dir: str,
                 masks_dir: str,
                 path_to_csv: str,
                 display_plot: bool = True
                 ):
        """Initialization."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("device:", self.device)
        self.display_plot = display_plot
        self.net = net.to(self.device)
        self.criterion = criterion
        self.optimizer = Adam(self.net.parameters(), lr=lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
                                           patience=3, verbose=True)
        # Convert the target effective batch size into the number of batches
        # between optimizer steps; at least 1, so the optimizer still steps
        # when accumulation_steps < batch_size.
        self.accumulation_steps = max(1, accumulation_steps // batch_size)
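        # Illustrative arithmetic (example numbers, not project values):
        # accumulation_steps=32 with batch_size=8 gives 32 // 8 = 4, so
        # gradients from 4 batches are accumulated per optimizer step,
        # for an effective batch of 32 samples per update.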
        self.phases = ["train", "val"]
        self.num_epochs = num_epochs

        self.dataloaders = {
            phase: get_dataloader(
                imgs_dir=imgs_dir,
                masks_dir=masks_dir,
                path_to_csv=path_to_csv,
                phase=phase,
                batch_size=batch_size,
                num_workers=6
            )
            for phase in self.phases
        }
        self.best_loss = float("inf")
        self.losses = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.jaccard_scores = {phase: [] for phase in self.phases}

    def _compute_loss_and_outputs(self,
                                  images: torch.Tensor,
                                  targets: torch.Tensor):
        images = images.to(self.device)
        targets = targets.to(self.device)
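        # The criterion receives raw logits; BCEDiceLoss is assumed to apply
        # the sigmoid internally (BCEWithLogits-style), as is usual for a
        # combined BCE + Dice objective.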
        logits = self.net(images)
        loss = self.criterion(logits, targets)
        return loss, logits

    def _do_epoch(self, epoch: int, phase: str):
        print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")

        # Toggle layers such as dropout and batch norm between modes.
        if phase == "train":
            self.net.train()
        else:
            self.net.eval()
        meter = Meter()
        dataloader = self.dataloaders[phase]
        total_batches = len(dataloader)
        running_loss = 0.0
        self.optimizer.zero_grad()
        for itr, (images, targets) in enumerate(dataloader):
            loss, logits = self._compute_loss_and_outputs(images, targets)
            # Scale down so the accumulated gradient matches one full step.
            loss = loss / self.accumulation_steps
            if phase == "train":
                loss.backward()
                if (itr + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            meter.update(logits.detach().cpu(),
                         targets.detach().cpu()
                         )

        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
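        # running_loss summed losses already divided by accumulation_steps,
        # so multiplying back by accumulation_steps and dividing by the batch
        # count recovers the mean per-batch loss. Illustrative numbers: with
        # accumulation_steps=4 and 100 batches, the sum holds 100 quarter-
        # losses, and sum * 4 / 100 is the average unscaled loss.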
        epoch_dice, epoch_iou = meter.get_metrics()

        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(epoch_dice)
        self.jaccard_scores[phase].append(epoch_iou)

        return epoch_loss

    def train(self):
        for epoch in range(self.num_epochs):
            self._do_epoch(epoch, "train")
            with torch.no_grad():
                val_loss = self._do_epoch(epoch, "val")
            # ReduceLROnPlateau lowers the lr when val loss stops improving.
            self.scheduler.step(val_loss)
            if self.display_plot:
                self._plot_train_history()

            # Checkpoint whenever the validation loss improves.
            if val_loss < self.best_loss:
                print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
                self.best_loss = val_loss
                torch.save(self.net.state_dict(), "best_model.pth")
            print()
        self._save_train_history()

    def _plot_train_history(self):
        data = [self.losses, self.dice_scores, self.jaccard_scores]
        colors = ["deepskyblue", "crimson"]
        labels = [
            f"""
            train loss {self.losses['train'][-1]:.4f}
            val loss {self.losses['val'][-1]:.4f}
            """,

            f"""
            train dice score {self.dice_scores['train'][-1]:.4f}
            val dice score {self.dice_scores['val'][-1]:.4f}
            """,

            f"""
            train jaccard score {self.jaccard_scores['train'][-1]:.4f}
            val jaccard score {self.jaccard_scores['val'][-1]:.4f}
            """,
        ]

        clear_output(True)
        # "seaborn-dark-palette" was renamed "seaborn-v0_8-dark-palette"
        # in matplotlib >= 3.6; adjust the name to match your version.
        with plt.style.context("seaborn-dark-palette"):
            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
            for i, ax in enumerate(axes):
                ax.plot(data[i]['val'], c=colors[0], label="val")
                ax.plot(data[i]['train'], c=colors[-1], label="train")
                ax.set_title(labels[i])
                ax.legend(loc="upper right")

            plt.tight_layout()
            plt.show()

    def load_pretrained_model(self,
                              state_path: str):
        # map_location keeps loading robust when the checkpoint was saved on
        # a different device (e.g. trained on GPU, loaded on CPU).
        self.net.load_state_dict(torch.load(state_path,
                                            map_location=self.device))
        print("Pretrained model loaded")
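    # Hypothetical usage (the path is a placeholder):
    #   trainer.load_pretrained_model("best_model.pth")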

    def _save_train_history(self):
        """Write model weights and training logs to files."""
        torch.save(self.net.state_dict(), "last_epoch_model.pth")

        # Flatten the per-phase metric dicts into flat columns named
        # "<phase><suffix>", e.g. "train_loss", "val_dice".
        logs_ = [self.losses, self.dice_scores, self.jaccard_scores]
        log_names_ = ["_loss", "_dice", "_jaccard"]
        logs = [logs_[i][key] for i in range(len(logs_))
                for key in logs_[i]]
        log_names = [key + log_names_[i]
                     for i in range(len(logs_))
                     for key in logs_[i]]
        pd.DataFrame(
            dict(zip(log_names, logs))
        ).to_csv("train_log.csv", index=False)
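

# Minimal usage sketch (not part of the original script): it wires the Trainer
# to the Unet and BCEDiceLoss imported above. The paths and hyperparameter
# values below are placeholders, and BCEDiceLoss is assumed to take no
# required constructor arguments.
if __name__ == "__main__":
    model = Unet("resnet34", encoder_weights="imagenet",
                 classes=1, activation=None)  # raw logits for BCE + Dice
    trainer = Trainer(net=model,
                      criterion=BCEDiceLoss(),
                      lr=5e-4,
                      accumulation_steps=32,
                      batch_size=8,
                      num_epochs=20,
                      imgs_dir="data/imgs",           # placeholder path
                      masks_dir="data/masks",         # placeholder path
                      path_to_csv="data/train.csv")   # placeholder path
    trainer.train()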