[c1a3f2]: / trainer.py

Download this file

104 lines (87 with data), 3.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
import torch
import torch.nn as nn
from collections import OrderedDict
from tqdm import tqdm
from utils import write_csv
import sys
# Metric names, in the order they are written per iteration and returned.
_METRIC_KEYS = ('ce_loss', 'dice_score', 'dice_loss', 'iou_score',
                'iou_loss', 'hausdorff', 'loss')


def _model_device(model):
    """Return the device holding ``model``'s first parameter.

    Falls back to the default CUDA device when the model has no parameters,
    which is where the previous hard-coded ``.cuda()`` calls sent tensors.
    """
    for param in model.parameters():
        return param.device
    return torch.device('cuda')


def _batch_metrics(model, input_batch, target, device, ce, dice, iou, hd):
    """Run one forward pass and compute every metric for a single batch.

    Returns:
        ``(loss, scalars, class_dice_score, class_dice_loss, class_iou)`` where
        ``loss`` is the differentiable combined loss tensor, ``scalars`` is an
        OrderedDict keyed by ``_METRIC_KEYS`` holding Python floats (except
        ``hausdorff``, passed through exactly as ``hd`` returned it), and the
        remaining items are the per-class values returned by ``dice``/``iou``.
    """
    # Insert a singleton axis at dim 1 -- presumably the channel axis the
    # model expects; TODO confirm against the dataset's sample shape.
    inputs = input_batch.unsqueeze(1).to(device)
    target = target.to(device)
    # The model returns four outputs; only the first (logits) is used here.
    logits, _, _, _ = model(inputs)
    ce_loss = ce(logits, target)
    dice_score, dice_loss, class_dice_score, class_dice_loss = dice(logits, target)
    iou_score, class_iou = iou(logits, target)
    hausdorff = hd(logits, target)
    # Optimised objective: weighted Dice loss plus weighted IoU loss.
    # ce_loss is only reported; it does not contribute to the gradient.
    loss = dice_loss * 0.4 + (1 - iou_score) * 0.6
    # Call .item() once per tensor: every call is a device->host sync.
    iou_value = iou_score.item()
    scalars = OrderedDict([
        ('ce_loss', ce_loss.item()),
        ('dice_score', dice_score.item()),
        ('dice_loss', dice_loss.item()),
        ('iou_score', iou_value),
        ('iou_loss', 1.0 - iou_value),
        ('hausdorff', hausdorff),
        ('loss', loss.item()),
    ])
    return loss, scalars, class_dice_score, class_dice_loss, class_iou


def _average(totals, n_batches):
    """Divide each accumulated metric by the number of batches processed.

    Raises:
        ValueError: if no batches were processed (empty loader).
    """
    if n_batches == 0:
        raise ValueError('data loader yielded no batches; cannot average metrics')
    return OrderedDict((key, totals[key] / n_batches) for key in _METRIC_KEYS)


def trainer(config, train_loader, optimizer, model, ce, dice, iou, hd):
    """Run one training epoch over ``train_loader`` and return mean metrics.

    For every batch: run the model forward, compute all metrics, pass one row
    of scalar metrics and the per-class Dice/IoU values to ``write_csv`` under
    ``outputs/<config.name>/``, then backpropagate
    ``0.4 * dice_loss + 0.6 * (1 - iou_score)`` and step the optimizer.

    Args:
        config: object exposing ``name``; only used to build the CSV paths.
        train_loader: iterable of ``(input, target)`` batches supporting
            ``len()`` (used as the progress-bar total).
        optimizer: optimizer stepped after each backward pass.
        model: module whose call returns a 4-tuple; the first item is logits.
        ce, dice, iou, hd: metric callables taking ``(logits, target)``.
            ``dice`` returns ``(score, loss, class_scores, class_losses)`` and
            ``iou`` returns ``(score, class_scores)``.

    Returns:
        OrderedDict mapping each name in ``_METRIC_KEYS`` to its mean over
        the processed batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    device = _model_device(model)
    out_dir = f'outputs/{config.name}'
    totals = dict.fromkeys(_METRIC_KEYS, 0.0)
    n_batches = 0
    # A single progress bar, closed even if an iteration raises.
    with tqdm(total=len(train_loader)) as pbar:
        for input_batch, target in train_loader:
            loss, scalars, class_dice_score, class_dice_loss, class_iou = (
                _batch_metrics(model, input_batch, target, device,
                               ce, dice, iou, hd))
            for key, value in scalars.items():
                totals[key] += value
            n_batches += 1
            write_csv(f'{out_dir}/iter_log.csv', list(scalars.values()))
            write_csv(f'{out_dir}/ds_class_iter.csv', class_dice_score)
            write_csv(f'{out_dir}/dl_loss_iter.csv', class_dice_loss)
            write_csv(f'{out_dir}/iou_class_iter.csv', class_iou)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.update(1)
    return _average(totals, n_batches)


def validate(config, val_loader, model, ce, dice, iou, hd):
    """Evaluate ``model`` on ``val_loader`` without gradients; return mean metrics.

    Uses the same metrics and combined loss as ``trainer`` but writes no CSV
    files and performs no optimisation. Leaves the model in eval mode.

    Args:
        config: unused; kept so the signature mirrors ``trainer``.
        val_loader: iterable of ``(input, target)`` batches.
        model: module whose call returns a 4-tuple; the first item is logits.
        ce, dice, iou, hd: metric callables taking ``(logits, target)``, with
            the same return conventions as in ``trainer``.

    Returns:
        OrderedDict mapping each name in ``_METRIC_KEYS`` to its mean over
        the processed batches.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    device = _model_device(model)
    totals = dict.fromkeys(_METRIC_KEYS, 0.0)
    n_batches = 0
    with torch.no_grad():
        for input_batch, target in val_loader:
            _, scalars, _, _, _ = _batch_metrics(
                model, input_batch, target, device, ce, dice, iou, hd)
            for key, value in scalars.items():
                totals[key] += value
            n_batches += 1
    return _average(totals, n_batches)