libs/network/train_functions.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import math
import cv2
from collections import namedtuple
from utils.metrics import dice, cal_hausdorff_distance
from utils.vis_utils import batchToColorImg, masks_to_contours

def model_fn_decorator():
    """Build the per-iteration closure used by the plain segmentation trainer."""
    ModelReturn = namedtuple("ModelReturn", ["loss", "tb_dict", "disp_dict"])

    def model_fn(model, data, criterion, perfermance=False, vis=False, device="cuda", epoch=0, num_class=4):
        # imgs, gts, _ = data
        imgs, gts = data[:2]

        imgs = imgs.to(device)
        gts = torch.squeeze(gts, 1).to(device)

        # the model returns a tuple; the first element holds the segmentation logits
        net_out = model(imgs)

        loss = criterion(net_out[0], gts.long())

        tb_dict = {}
        disp_dict = {}
        tb_dict.update({"loss": loss.item()})
        disp_dict.update({"loss": loss.item()})

        if perfermance:
            # evaluate dice / Hausdorff metrics on the argmax prediction
            gts_ = gts.unsqueeze(1)

            net_out = F.softmax(net_out[0], dim=1)
            _, preds = torch.max(net_out, 1)
            preds = preds.unsqueeze(1)
            cal_perfer(make_one_hot(preds, num_class), make_one_hot(gts_, num_class), tb_dict)

        return ModelReturn(loss, tb_dict, disp_dict)

    return model_fn

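# Usage sketch (illustrative only, not part of the original training code): how the
# closure returned by model_fn_decorator() is typically driven for one optimization
# step. The optimizer argument and this helper's name are assumptions.
def _example_train_step(model, batch, criterion, optimizer, device="cuda"):
    model_fn = model_fn_decorator()
    ret = model_fn(model, batch, criterion, perfermance=False, device=device)
    optimizer.zero_grad()
    ret.loss.backward()        # backprop through the cross-entropy loss
    optimizer.step()
    return ret.tb_dict         # scalar stats ("loss", optionally dice/Hausdorff)
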
def model_DF_decorator():
    """Build the per-iteration closure for the direction-field (DF) variant."""
    ModelReturn = namedtuple("ModelReturn", ["loss", "tb_dict", "disp_dict"])

    def model_fn(model, data, criterion=None, perfermance=False, vis=False, device="cuda", epoch=0, num_class=4):
        imgs, gts = data[:2]
        gts_df, dist_maps = data[2:]

        imgs = imgs.to(device)
        gts = torch.squeeze(gts, 1).to(device).long()
        gts_df = gts_df.to(device)

        net_out = model(imgs)
        seg_out, df_out = net_out[:2]

        # optional auxiliary segmentation head (third network output)
        if len(net_out) >= 3 and net_out[2] is not None:
            auxseg_out = net_out[2]
            auxseg_loss = F.cross_entropy(auxseg_out, gts)
        else:
            auxseg_loss = torch.tensor([0.], dtype=torch.float32, device=device)

        # loss = criterion(net_out, gts.long())
        # segmentation loss
        seg_loss = F.cross_entropy(seg_out, gts)

        # direction field loss (plus boundary loss from the criterion)
        df_loss, boundary_loss = criterion(seg_out, dist_maps, df_out, gts_df, gts)

        alpha = 1.0
        loss = alpha * (seg_loss + 1. * df_loss + 0.1 * auxseg_loss) + (1. - alpha) * boundary_loss

        tb_dict = {}
        disp_dict = {}
        tb_dict.update({"loss": loss.item(), "seg_loss": alpha * seg_loss.item(), "df_loss": alpha * 1. * df_loss.item(),
                        "boundary_loss": (1. - alpha) * boundary_loss.item(), "auxseg_loss": alpha * 0.1 * auxseg_loss.item()})
        disp_dict.update({"loss": loss.item()})

        if perfermance:
            gts_ = gts.unsqueeze(1)

            seg_out = F.softmax(seg_out, dim=1)
            _, preds = torch.max(seg_out, 1)
            preds = preds.unsqueeze(1)
            cal_perfer(make_one_hot(preds, num_class), make_one_hot(gts_, num_class), tb_dict)

        if vis:
            # visualize the direction field (angle and magnitude as color images)
            # vis_dict = {}
            gt_df = gts_df.cpu().numpy()
            _, angle_gt = cv2.cartToPolar(gt_df[:, 0, ...], gt_df[:, 1, ...])
            angle_gt = batchToColorImg(angle_gt, minv=0, maxv=2 * math.pi).transpose(0, 3, 1, 2)

            df_map = df_out.detach().cpu().numpy()
            mag, angle_df = cv2.cartToPolar(df_map[:, 0, ...], df_map[:, 1, ...])
            angle_df = batchToColorImg(angle_df, minv=0, maxv=2 * math.pi).transpose(0, 3, 1, 2)
            mag = batchToColorImg(mag).transpose(0, 3, 1, 2)

            tb_dict.update({"vis": [angle_gt, mag, angle_df]})

        return ModelReturn(loss, tb_dict, disp_dict)

    return model_fn

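# Sketch (an assumption about the data pipeline, not code from this repo): one common
# way to build the ground-truth direction field gts_df consumed above. For every
# foreground pixel we take the offset from its nearest background pixel, normalized
# to a unit vector; background pixels get a zero vector. The helper name is illustrative.
def _example_direction_field(mask):
    """mask: (H, W) binary numpy array -> (2, H, W) float32 direction field."""
    from scipy import ndimage

    h, w = mask.shape
    _, nearest_bg = ndimage.distance_transform_edt(mask, return_indices=True)
    grid = np.mgrid[0:h, 0:w]
    df = (grid - nearest_bg).astype(np.float32)   # vector from nearest background pixel
    norm = np.maximum(np.hypot(df[0], df[1]), 1e-6)
    df /= norm                                    # normalize to unit length
    df[:, mask == 0] = 0.                         # zero vectors outside the foreground
    return df
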
def cal_perfer(preds, masks, tb_dict):
    """Compute per-class dice and Hausdorff metrics and write the batch means into tb_dict."""
    LV_dice = []        # class 1
    MYO_dice = []       # class 2
    RV_dice = []        # class 3
    LV_hausdorff = []
    MYO_hausdorff = []
    RV_hausdorff = []

    for i in range(preds.shape[0]):
        LV_dice.append(dice(preds[i, 1, :, :], masks[i, 1, :, :]))
        RV_dice.append(dice(preds[i, 3, :, :], masks[i, 3, :, :]))
        MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))

        LV_hausdorff.append(cal_hausdorff_distance(preds[i, 1, :, :], masks[i, 1, :, :]))
        RV_hausdorff.append(cal_hausdorff_distance(preds[i, 3, :, :], masks[i, 3, :, :]))
        MYO_hausdorff.append(cal_hausdorff_distance(preds[i, 2, :, :], masks[i, 2, :, :]))

    tb_dict.update({"LV_dice": np.mean(LV_dice)})
    tb_dict.update({"RV_dice": np.mean(RV_dice)})
    tb_dict.update({"MYO_dice": np.mean(MYO_dice)})
    tb_dict.update({"LV_hausdorff": np.mean(LV_hausdorff)})
    tb_dict.update({"RV_hausdorff": np.mean(RV_hausdorff)})
    tb_dict.update({"MYO_hausdorff": np.mean(MYO_hausdorff)})

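# Sketch (illustrative only): forwarding the scalars collected by cal_perfer to a
# torch.utils.tensorboard SummaryWriter; the writer and global_step come from the caller.
def _example_log_metrics(writer, tb_dict, global_step):
    for key, value in tb_dict.items():
        if key != "vis":                       # "vis" holds image arrays, not scalars
            writer.add_scalar(key, value, global_step)
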
def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.

    Args:
        input: A tensor of shape [N, 1, *] holding class indices.
        num_classes: Number of classes.
    Returns:
        A tensor of shape [N, num_classes, *].
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1)
    # result = result.scatter_(1, input.cpu(), 1)

    return result
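
# Minimal shape check (illustrative only): a [N, 1, H, W] index map with values in
# [0, num_classes) becomes a [N, num_classes, H, W] one-hot tensor on the CPU.
def _example_one_hot_check():
    preds = torch.randint(0, 4, (2, 1, 224, 224))
    one_hot = make_one_hot(preds, 4)
    assert one_hot.shape == (2, 4, 224, 224)
    assert torch.all(one_hot.sum(dim=1) == 1)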