Diff of /trainer.py [000000] .. [70e190]

Switch to unified view

a b/trainer.py
1
import torch
2
import torch.nn as nn
3
from collections import OrderedDict
4
from tqdm import tqdm
5
from utils import write_csv
6
import sys
7
8
def trainer(config, train_loader, optimizer, model, ce, dice, iou, hd):        
9
    model.train()
10
    steps = len(train_loader)
11
    pbar = tqdm(total=steps)
12
    
13
    total_ce_loss, total_dice_score, total_dice_loss, \
14
    total_iou_score, total_iou_loss, total_loss, total_hausdorff = 0.0,0.0,0.0,0.0,0.0,0.0,0.0
15
    
16
    for iter, (input, target) in tqdm(enumerate(train_loader)):
17
        sys.stdout.write(f"\riter: {iter+1}/{steps}")
18
        sys.stdout.flush()
19
        
20
        input = input.unsqueeze(1).cuda()
21
        target = target.cuda()
22
        logits, _, _, _ = model(input)
23
        
24
        ce_loss = ce(logits, target)
25
        dice_score, dice_loss, class_dice_score, class_dice_loss = dice(logits, target)
26
        iou_score, class_iou = iou(logits, target)
27
        hausdorff = hd(logits, target)
28
        loss = dice_loss*0.4 + (1 - iou_score)*0.6
29
        
30
        total_ce_loss += ce_loss.item()
31
        total_dice_score += dice_score.item()
32
        total_dice_loss += dice_loss.item()
33
        total_iou_score += iou_score.item()
34
        total_iou_loss += 1.0-iou_score.item()
35
        total_hausdorff += hausdorff
36
        total_loss += loss.item()
37
            
38
        write_csv(f'outputs/{config.name}/iter_log.csv', [
39
            ce_loss.item(),
40
            dice_score.item(),
41
            dice_loss.item(),
42
            iou_score.item(),
43
            1.0-iou_score.item(),
44
            hausdorff,
45
            loss.item()
46
        ])
47
        
48
        write_csv(f'outputs/{config.name}/ds_class_iter.csv', class_dice_score)
49
        write_csv(f'outputs/{config.name}/dl_loss_iter.csv', class_dice_loss)
50
        write_csv(f'outputs/{config.name}/iou_class_iter.csv', class_iou)
51
                                        
52
        optimizer.zero_grad()
53
        loss.backward()
54
        optimizer.step()
55
        pbar.update(1)
56
    
57
    pbar.close()           
58
    return OrderedDict([
59
        ('ce_loss', total_ce_loss / steps),
60
        ('dice_score', total_dice_score / steps),
61
        ('dice_loss', total_dice_loss / steps),
62
        ('iou_score', total_iou_score / steps),
63
        ('iou_loss', total_iou_loss / steps),
64
        ('hausdorff', total_hausdorff / steps),
65
        ('loss', total_loss / steps)
66
    ])
67
    
68
    
69
def validate(config, val_loader, model, ce, dice, iou, hd):
70
    model.eval()
71
    steps = len(val_loader)
72
    
73
    total_ce_loss, total_dice_score, total_dice_loss, \
74
    total_iou_score, total_iou_loss, total_loss, total_hausdorff = 0.0,0.0,0.0,0.0,0.0,0.0,0.0
75
    
76
    with torch.no_grad():
77
        for input, target in val_loader:
78
            input = input.unsqueeze(1).cuda()
79
            target = target.cuda()
80
            logits, _, _, _ = model(input)
81
        
82
            ce_loss = ce(logits, target)
83
            dice_score, dice_loss, _, _ = dice(logits, target)
84
            iou_score, _ = iou(logits, target)
85
            hausdorff = hd(logits, target)
86
            loss = dice_loss*0.4 + (1 - iou_score)*0.6
87
            
88
            total_ce_loss += ce_loss.item()
89
            total_dice_score += dice_score.item()
90
            total_dice_loss += dice_loss.item()
91
            total_iou_score += iou_score.item()
92
            total_iou_loss += 1.0-iou_score.item()
93
            total_hausdorff += hausdorff
94
            total_loss += loss.item()
95
            
96
    return OrderedDict([
97
        ('ce_loss', total_ce_loss / steps),
98
        ('dice_score', total_dice_score / steps),
99
        ('dice_loss', total_dice_loss / steps),
100
        ('iou_score', total_iou_score / steps),
101
        ('iou_loss', total_iou_loss / steps),
102
        ('hausdorff', total_hausdorff / steps),
103
        ('loss', total_loss / steps)
104
    ])