src/train_functions.py
import copy
import os
import time

import torch
import wandb
from tqdm import tqdm

from data_functions import get_loaders
from utils import (OneHotEncoder, get_criterion, get_metric, get_model,
                   get_optimizer, get_scheduler)


# Helper that pauses briefly so tqdm progress bars render correctly in Colaboratory.
def foo_():
    time.sleep(0.3)


def train_epoch(model, train_dl, encoder, criterion, metric, optimizer, scheduler, device):
    model.train()
    loss_sum = 0
    score_sum = 0
    n_batches = 0  # count only the batches actually used
    with tqdm(total=len(train_dl), position=0, leave=True) as pbar:
        # iterate the loader directly; wrapping it in a second tqdm drew a duplicate bar
        for X, y in train_dl:
            pbar.update()
            X = X.to(device)
            # skip degenerate batches (e.g. completely blank scans)
            if len(torch.unique(X)) == 1:
                continue
            if encoder is not None:
                y = encoder(y)  # one-hot encode the segmentation masks
            y = y.squeeze(4)  # drop the singleton dimension left by the encoder
            y = y.to(device)

            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_sum += loss.item()
            score_sum += metric(output, y).mean().item()
            n_batches += 1
    # average over processed batches so the skipped ones do not bias the result
    return loss_sum / n_batches, score_sum / n_batches


def eval_epoch(model, val_dl, encoder, criterion, metric, device):
    model.eval()
    loss_sum = 0
    score_sum = 0
    n_batches = 0
    with tqdm(total=len(val_dl), position=0, leave=True) as pbar:
        for X, y in val_dl:
            pbar.update()
            X = X.to(device)
            # skip degenerate batches, mirroring train_epoch
            if len(torch.unique(X)) == 1:
                continue
            if encoder is not None:
                y = encoder(y)
            # squeeze the same dimension as in train_epoch: a bare squeeze()
            # would also drop a batch dimension of size 1
            y = y.squeeze(4)
            y = y.to(device)

            with torch.no_grad():
                output = model(X)
                loss = criterion(output, y).item()
                score = metric(output, y).mean().item()
            loss_sum += loss
            score_sum += score
            n_batches += 1
    return loss_sum / n_batches, score_sum / n_batches


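# Note on the metric contract assumed above: metric(output, y) must return a
# per-sample tensor whose .mean() gives the batch score. The real metric comes
# from get_metric(cfg) in utils; the soft-Dice sketch below only illustrates
# that contract and is not the project's implementation:
#
#     def soft_dice(output, y, eps=1e-6):
#         probs = torch.softmax(output, dim=1)          # [B, C, ...]
#         dims = tuple(range(2, output.ndim))           # spatial dimensions
#         intersection = (probs * y).sum(dims)          # [B, C]
#         union = probs.sum(dims) + y.sum(dims)         # [B, C]
#         return ((2 * intersection + eps) / (union + eps)).mean(dim=1)  # [B]

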
def run(cfg, model_name, use_wandb=True, max_early_stopping=2):
    torch.cuda.empty_cache()

    # <<<<< SETUP >>>>>
    train_loader, val_loader = get_loaders(cfg)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(cfg)(cfg=cfg).to(device)
    optimizer = get_optimizer(cfg)(model.parameters(), **cfg.optimizer_params)
    scheduler = get_scheduler(cfg)(optimizer, **cfg.scheduler_params)
    metric = get_metric(cfg)(**cfg.metric_params)
    criterion = get_criterion(cfg)(**cfg.criterion_params)
    encoder = OneHotEncoder(cfg)

    # wandb is watching
    if use_wandb:
        wandb.init(project='Covid19_CT_segmentation_' + str(cfg.dataset_name),
                   entity='aiijcteamname', config=cfg, name=model_name)
        wandb.watch(model, log_freq=100)

    best_val_loss = float('inf')
    last_train_loss = 0
    last_val_loss = float('inf')
    early_stopping_flag = 0
    # deepcopy: state_dict() returns live references that later training steps
    # would mutate in place
    best_state_dict = copy.deepcopy(model.state_dict())
    os.makedirs('checkpoints', exist_ok=True)  # make sure the save directory exists
    for epoch in range(1, cfg.epochs + 1):
        print(f'Epoch #{epoch}')

        # <<<<< TRAIN >>>>>
        train_loss, train_score = train_epoch(model, train_loader, encoder,
                                              criterion, metric,
                                              optimizer, scheduler, device)
        print('        Score  |  Loss')
        print(f'Train: {train_score:.6f} | {train_loss:.6f}')

        # <<<<< EVAL >>>>>
        val_loss, val_score = eval_epoch(model, val_loader, encoder,
                                         criterion, metric, device)
        print(f'Val:   {val_score:.6f} | {val_loss:.6f}', end='\n\n')
        metrics = {'train_score': train_score,
                   'train_loss': train_loss,
                   'val_score': val_score,
                   'val_loss': val_loss,
                   'lr': scheduler.get_last_lr()[-1]}

        if use_wandb:  # log metrics to wandb
            wandb.log(metrics)

        # save the best weights so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state_dict = copy.deepcopy(model.state_dict())
            torch.save(best_state_dict, os.path.join('checkpoints', model_name + '.pth'))

        # guard against over-fitting: count epochs where train loss fell while
        # val loss rose, and stop after max_early_stopping of them
        if train_loss < last_train_loss and val_loss > last_val_loss:
            early_stopping_flag += 1
            if early_stopping_flag == max_early_stopping:
                print('<<< EarlyStopping >>>')
                break

        last_train_loss = train_loss
        last_val_loss = val_loss

    # load the best weights back into the model
    model.load_state_dict(best_state_dict)

    if use_wandb:
        wandb.finish()
    return model
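

# Example usage (a sketch, not part of the original pipeline): the real config
# object is built elsewhere in the repo, so every field value below is an
# assumption, and get_loaders/get_model may read further fields. Only the
# attribute names are taken from the code above.
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(
#         dataset_name='medseg',            # hypothetical dataset identifier
#         epochs=20,
#         optimizer_params={'lr': 3e-4},
#         scheduler_params={'T_max': 20},   # assumes a CosineAnnealingLR-style scheduler
#         metric_params={},
#         criterion_params={},
#     )
#     model = run(cfg, model_name='unet_baseline', use_wandb=False)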