import copy
import os
import time

import torch
import wandb
from tqdm import tqdm

from data_functions import get_loaders
from utils import (OneHotEncoder, get_criterion, get_metric, get_model,
                   get_optimizer, get_scheduler)


def foo_():
    """Brief sleep so tqdm progress bars render correctly in Colaboratory."""
    time.sleep(0.3)


def train_epoch(model, train_dl, encoder, criterion, metric, optimizer, scheduler, device):
    """Run one training epoch.

    Batches whose input volume is constant (a single unique value, e.g. an
    all-background crop) are skipped: they carry no gradient signal.

    Returns:
        (mean_loss, mean_score) averaged over the batches that were actually
        processed (skipped batches are excluded from the denominator).
    """
    model.train()
    loss_sum = 0.0
    score_sum = 0.0
    n_processed = 0  # only batches that reached the backward pass
    with tqdm(total=len(train_dl), position=0, leave=True) as pbar:
        # Iterate the loader directly: wrapping it in tqdm() again (as before)
        # created a second, duplicate progress bar.
        for X, y in train_dl:
            pbar.update()
            X = X.to(device)
            if len(torch.unique(X)) == 1:
                continue
            if encoder is not None:
                y = encoder(y)
                # NOTE(review): train squeezes dim 4 while eval squeezes all
                # singleton dims — presumably the encoder emits a trailing
                # singleton axis; confirm shapes against OneHotEncoder.
                y = y.squeeze(4)
            y = y.to(device)

            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            scheduler.step()  # per-batch scheduler step (e.g. OneCycleLR)

            loss_sum += loss.item()
            score_sum += metric(output, y).mean().item()
            n_processed += 1
    # Average over processed batches; guard against an epoch where every
    # batch was skipped.
    n_processed = max(n_processed, 1)
    return loss_sum / n_processed, score_sum / n_processed


def eval_epoch(model, val_dl, encoder, criterion, metric, device):
    """Run one validation epoch (no gradient tracking, no optimizer step).

    Returns:
        (mean_loss, mean_score) averaged over the batches that were actually
        processed (constant-input batches are skipped, as in training).
    """
    model.eval()
    loss_sum = 0.0
    score_sum = 0.0
    n_processed = 0  # only batches that were actually evaluated
    with tqdm(total=len(val_dl), position=0, leave=True) as pbar:
        # Single progress bar (the loader was previously tqdm-wrapped twice).
        for X, y in val_dl:
            pbar.update()
            X = X.to(device)
            if len(torch.unique(X)) == 1:
                continue
            if encoder is not None:
                y = encoder(y)
                y = y.squeeze()
            y = y.to(device)

            with torch.no_grad():
                output = model(X)
                loss_sum += criterion(output, y).item()
                score_sum += metric(output, y).mean().item()
            n_processed += 1
    n_processed = max(n_processed, 1)
    return loss_sum / n_processed, score_sum / n_processed


def run(cfg, model_name, use_wandb=True, max_early_stopping=2):
    """Full training loop: build everything from *cfg*, train, checkpoint.

    Args:
        cfg: experiment config consumed by the get_* factories.
        model_name: run name; also the checkpoint filename stem.
        use_wandb: when True, log metrics and watch gradients with wandb.
        max_early_stopping: number of (train down, val up) epochs tolerated
            before stopping early.

    Returns:
        The model with its best-validation-loss weights loaded.
    """
    torch.cuda.empty_cache()

    # <<<<< SETUP >>>>>
    train_loader, val_loader = get_loaders(cfg)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(cfg)(cfg=cfg).to(device)
    optimizer = get_optimizer(cfg)(model.parameters(), **cfg.optimizer_params)
    scheduler = get_scheduler(cfg)(optimizer, **cfg.scheduler_params)
    metric = get_metric(cfg)(**cfg.metric_params)
    criterion = get_criterion(cfg)(**cfg.criterion_params)
    encoder = OneHotEncoder(cfg)

    # wandb is watching
    if use_wandb:
        wandb.init(project='Covid19_CT_segmentation_' + str(cfg.dataset_name), entity='aiijcteamname', config=cfg,
                   name=model_name)
        wandb.watch(model, log_freq=100)

    # Ensure the checkpoint directory exists before the first torch.save.
    os.makedirs('checkpoints', exist_ok=True)

    best_val_loss = float('inf')
    last_train_loss = 0
    last_val_loss = float('inf')
    early_stopping_flag = 0
    # Deep-copy: state_dict() returns references to the live parameters, so
    # without a copy the "best" snapshot would be mutated by later epochs and
    # the final load_state_dict would restore the LAST weights, not the best.
    best_state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(1, cfg.epochs + 1):
        print(f'Epoch #{epoch}')

        # <<<<< TRAIN >>>>>
        train_loss, train_score = train_epoch(model, train_loader, encoder,
                                              criterion, metric,
                                              optimizer, scheduler, device)
        print(' Score | Loss')
        print(f'Train: {train_score:.6f} | {train_loss:.6f}')

        # <<<<< EVAL >>>>>
        val_loss, val_score = eval_epoch(model, val_loader, encoder,
                                         criterion, metric, device)
        print(f'Val: {val_score:.6f} | {val_loss:.6f}', end='\n\n')
        metrics = {'train_score': train_score,
                   'train_loss': train_loss,
                   'val_score': val_score,
                   'val_loss': val_loss,
                   'lr': scheduler.get_last_lr()[-1]}

        if use_wandb:  # log metrics to wandb
            wandb.log(metrics)

        # saving best weights
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state_dict = copy.deepcopy(model.state_dict())
            torch.save(best_state_dict, os.path.join('checkpoints', model_name + '.pth'))

        # weapon counter over-fitting: train improving while val degrades
        if train_loss < last_train_loss and val_loss > last_val_loss:
            early_stopping_flag += 1
            if early_stopping_flag == max_early_stopping:
                print('<<< EarlyStopping >>>')
                break

        last_train_loss = train_loss
        last_val_loss = val_loss

    # loading best weights
    model.load_state_dict(best_state_dict)

    if use_wandb:
        wandb.finish()
    return model