a b/libs/losses/df_loss.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
import numpy as np
6
7
def LossSegDF(net_ret, data, device="cuda"):
    """Combined segmentation + direction-field loss.

    Args:
        net_ret: ``(net_out, df_out)`` pair; ``net_out`` holds the per-class
            logits given to ``F.cross_entropy`` and ``df_out`` the predicted
            direction field given to ``F.mse_loss``.
        data: ``(_, gts, gts_df)`` batch tuple. The first item is ignored,
            ``gts`` is the segmentation label map whose dim 1 is squeezed
            away, and ``gts_df`` is the target direction field.
        device: device both targets are moved to before computing the loss.

    Returns:
        Scalar tensor ``cross_entropy(net_out, gts) + mse_loss(df_out, gts_df)``.
    """
    net_out, df_out = net_ret
    _, gts, gts_df = data

    # class indices for cross-entropy must be integer
    gts = torch.squeeze(gts, 1).to(device).long()
    # The direction field is a real-valued regression target: keep it in the
    # prediction's floating dtype. Casting it to an integer dtype would
    # truncate every fractional component toward zero.
    gts_df = gts_df.to(device=device, dtype=df_out.dtype)

    # segmentation loss
    seg_loss = F.cross_entropy(net_out, gts)
    # direction field loss
    df_loss = F.mse_loss(df_out, gts_df)

    return seg_loss + df_loss
22
23
class EuclideanLossWithOHEM(nn.Module):
    """Weighted squared-L2 loss with online hard negative mining (OHEM).

    Pixels are split by a per-pixel weight map: weight > 0 marks positive
    pixels, weight == 0 marks negative candidates. Every positive pixel
    contributes with its weight. Among the negatives, only the ``k`` with the
    largest squared error contribute, each with weight 1, where per sample
    ``k = min(npRatio * #positives, #negatives)``.

    Args:
        npRatio: maximum number of hard negatives kept per positive pixel.
    """

    def __init__(self, npRatio=3):
        super(EuclideanLossWithOHEM, self).__init__()
        self.npRatio = npRatio

    def __cal_weight(self, gt):
        """Return the (H, W) positive-pixel weight map for one sample.

        ``gt`` is a ``(1, H, W)`` label map. Values > 0 are foreground segment
        labels; everything else is background. Every positive segment gets
        the same total weight ``posCount / numSegments``, spread evenly over
        its pixels, so small segments count as much as large ones. Background
        pixels get weight 0.
        """
        label_map = gt[0]
        weight = torch.zeros(label_map.shape, dtype=torch.float, device=gt.device)

        pos_region = label_map > 0
        pos_count = int(pos_region.sum())
        if pos_count == 0:
            return weight

        # Segments are defined by the positive labels only. Do not assume the
        # smallest label is background: a sample without background pixels
        # (or with negative labels) would otherwise lose a real segment or
        # turn the background into one.
        _, seg_index, seg_sizes = torch.unique(
            label_map[pos_region], return_inverse=True, return_counts=True
        )
        seg_ave = float(pos_count) / seg_sizes.numel()
        pix_ave = seg_ave / seg_sizes.to(torch.float)  # per-segment pixel weight
        weight[pos_region] = pix_ave[seg_index]
        return weight

    def forward(self, pred, gt_df, gt, weight=None):
        """Compute the OHEM-weighted Euclidean loss.

        Args:
            pred: (N, C, H, W) prediction.
            gt_df: (N, C, H, W) regression target.
            gt: (N, 1, H, W) label map. Only used to build ``weight`` when it
                is not given; values > 0 are foreground.
            weight: optional (N, H, W) per-pixel positive weights; > 0 marks
                positives, == 0 marks negative candidates.

        Returns:
            Scalar tensor ``sum(w * (pred - gt_df)**2) / N / 2 / sum(w)``.
            ``w`` is the positive weight plus 1 for each mined hard negative,
            broadcast over all C channels. The result is 0 when no pixel is
            selected.
        """
        N, _, H, W = pred.shape
        dist_l2 = (pred - gt_df) ** 2

        if weight is None:
            weight = torch.zeros((N, H, W), device=pred.device)
            for i in range(N):
                weight[i] = self.__cal_weight(gt[i])

        # amount of positive and negative pixels per sample, shape (N,)
        region_pos = (weight > 0).to(torch.float)
        region_neg = (weight == 0).to(torch.float)
        sum_pos = region_pos.sum(dim=(1, 2))
        sum_neg = region_neg.sum(dim=(1, 2))

        # amount of hard negatives to keep per sample
        num_hard_neg = torch.min(self.npRatio * sum_pos, sum_neg).to(torch.int)

        # Mining only decides *which* negatives count; it needs no gradient.
        with torch.no_grad():
            loss_neg = dist_l2.sum(dim=1) * region_neg  # (N, H, W), 0 on positives
            loss_flat = loss_neg.flatten(start_dim=1)
            hard_flat = torch.zeros_like(loss_flat, dtype=torch.bool)
            for i, k in enumerate(num_hard_neg.tolist()):
                # k == 0 (e.g. a sample without positives) must select nothing.
                if k > 0:
                    hard_flat[i, torch.topk(loss_flat[i], k).indices] = True
            # pixels with zero error (including all positives) are never hard negatives
            hard_flat &= loss_flat != 0
            hard_neg = hard_flat.view(N, H, W)

        # per-pixel weights, broadcast over every channel of pred
        weight_pos = weight.unsqueeze(1).expand_as(pred)
        weight_neg = hard_neg.unsqueeze(1).expand_as(pred).to(pred.dtype)
        weight_all = weight_pos + weight_neg

        denom = weight_all.sum()
        # When nothing is selected (e.g. an all-background batch), numerator and
        # denominator are both 0; divide by 1 instead of producing NaN.
        denom = torch.where(denom > 0, denom, torch.ones_like(denom))

        total_loss = torch.sum(dist_l2 * weight_all) / N / 2. / denom
        return total_loss
96
97
98
99
if __name__ == "__main__":
    import os

    # Pin to the first GPU unless the caller already selected devices; this
    # must happen before CUDA is queried.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    criterion = EuclideanLossWithOHEM()
    for i in range(100):
        gt_df = torch.randn((32, 2, 224, 224), device=device)
        gt = torch.randint(0, 4, (32, 1, 224, 224), device=device)

        # deliberately poor prediction (100x the target) so the loss is clearly non-zero
        pred = 100 * gt_df
        loss = criterion(pred, gt_df, gt)
        print("{:6} loss:{}".format(i, loss.item()))