Diff of /src/train_functions.py [000000] .. [cbdc43]

import torch
import os
import copy
from utils import get_scheduler, get_model, get_criterion, get_optimizer, get_metric, OneHotEncoder
from data_functions import get_loaders
import wandb
from tqdm import tqdm
import time


# helper so that tqdm output renders properly in Google Colaboratory
def foo_():
    time.sleep(0.3)


def train_epoch(model, train_dl, encoder, criterion, metric, optimizer, scheduler, device):
    model.train()
    loss_sum = 0
    score_sum = 0
    with tqdm(total=len(train_dl), position=0, leave=True) as pbar:
        for X, y in train_dl:  # single bar, driven manually via pbar.update()
            pbar.update()
            X = X.to(device)
            # skip degenerate batches that contain a single intensity value
            if len(torch.unique(X)) == 1:
                continue
            if encoder is not None:
                y = encoder(y)
            y = y.squeeze(4)  # drop the singleton dim at position 4
            y = y.to(device)

            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            scheduler.step()  # per-batch LR schedule

            loss_sum += loss.item()
            score_sum += metric(output, y).mean().item()
    return loss_sum / len(train_dl), score_sum / len(train_dl)
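

# NOTE: scheduler.step() above runs once per batch, so the scheduler built by
# get_scheduler is assumed to be a per-iteration one. A hypothetical setup
# consistent with that assumption (not the project's actual factory output):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = torch.optim.lr_scheduler.OneCycleLR(
#         optimizer, max_lr=1e-3,
#         steps_per_epoch=len(train_loader), epochs=cfg.epochs)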


def eval_epoch(model, val_dl, encoder, criterion, metric, device):
    model.eval()
    loss_sum = 0
    score_sum = 0
    with tqdm(total=len(val_dl), position=0, leave=True) as pbar:
        for X, y in val_dl:  # single bar, driven manually via pbar.update()
            pbar.update()
            X = X.to(device)
            if len(torch.unique(X)) == 1:
                continue
            if encoder is not None:
                y = encoder(y)
            y = y.squeeze()  # note: squeezes all singleton dims, unlike train_epoch
            y = y.to(device)

            with torch.no_grad():
                output = model(X)
                loss_sum += criterion(output, y).item()
                score_sum += metric(output, y).mean().item()
    return loss_sum / len(val_dl), score_sum / len(val_dl)
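

# The metric built by get_metric is assumed to return per-sample scores, since
# the loops above call .mean() on it. A minimal sketch of such a metric for
# binary segmentation (a hypothetical illustration, not the project's actual
# implementation):
class DiceScoreSketch:
    def __init__(self, eps=1e-6):
        self.eps = eps  # avoids division by zero on empty masks

    def __call__(self, output, target):
        probs = torch.sigmoid(output)
        dims = tuple(range(1, output.dim()))  # reduce every dim except batch
        inter = (probs * target).sum(dims)
        union = probs.sum(dims) + target.sum(dims)
        return (2 * inter + self.eps) / (union + self.eps)  # shape: (batch,)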


def run(cfg, model_name, use_wandb=True, max_early_stopping=2):
    torch.cuda.empty_cache()

    # <<<<< SETUP >>>>>
    train_loader, val_loader = get_loaders(cfg)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(cfg)(cfg=cfg).to(device)
    optimizer = get_optimizer(cfg)(model.parameters(), **cfg.optimizer_params)
    scheduler = get_scheduler(cfg)(optimizer, **cfg.scheduler_params)
    metric = get_metric(cfg)(**cfg.metric_params)
    criterion = get_criterion(cfg)(**cfg.criterion_params)
    encoder = OneHotEncoder(cfg)
    os.makedirs('checkpoints', exist_ok=True)  # ensure the checkpoint dir exists

    # wandb is watching
    if use_wandb:
        wandb.init(project='Covid19_CT_segmentation_' + str(cfg.dataset_name), entity='aiijcteamname', config=cfg,
                   name=model_name)
        wandb.watch(model, log_freq=100)

    best_val_loss = float('inf')
    last_train_loss = 0
    last_val_loss = float('inf')
    early_stopping_counter = 0
    # deep copy, otherwise the snapshot would alias the live parameters
    best_state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(1, cfg.epochs + 1):
        print(f'Epoch #{epoch}')

        # <<<<< TRAIN >>>>>
        train_loss, train_score = train_epoch(model, train_loader, encoder,
                                              criterion, metric,
                                              optimizer, scheduler, device)
        print('       Score    |    Loss')
        print(f'Train: {train_score:.6f} | {train_loss:.6f}')

        # <<<<< EVAL >>>>>
        val_loss, val_score = eval_epoch(model, val_loader, encoder,
                                         criterion, metric, device)
        print(f'Val:   {val_score:.6f} | {val_loss:.6f}', end='\n\n')
        metrics = {'train_score': train_score,
                   'train_loss': train_loss,
                   'val_score': val_score,
                   'val_loss': val_loss,
                   'lr': scheduler.get_last_lr()[-1]}

        if use_wandb:  # log metrics to wandb
            wandb.log(metrics)

        # saving best weights
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state_dict = copy.deepcopy(model.state_dict())
            torch.save(best_state_dict, os.path.join('checkpoints', model_name + '.pth'))

        # early stopping to counter over-fitting: count epochs where train loss
        # keeps falling while validation loss rises
        if train_loss < last_train_loss and val_loss > last_val_loss:
            early_stopping_counter += 1
        if early_stopping_counter == max_early_stopping:
            print('<<< EarlyStopping >>>')
            break

        last_train_loss = train_loss
        last_val_loss = val_loss

    # loading best weights
    model.load_state_dict(best_state_dict)

    if use_wandb:
        wandb.finish()
    return model
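

# A hedged usage sketch: the real cfg schema lives in the project's config
# code, so the fields below are hypothetical stand-ins covering only what this
# file reads directly. get_loaders/get_model/etc. will likely require more.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        dataset_name='covid19_ct',      # only used in the wandb project name
        epochs=10,
        optimizer_params={'lr': 1e-3},  # assumed optimizer kwargs
        scheduler_params={'max_lr': 1e-3, 'total_steps': 1000},  # assumed scheduler kwargs
        metric_params={},
        criterion_params={},
    )
    trained_model = run(cfg, model_name='baseline_unet', use_wandb=False)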