a b/libs/losses/mag_angle_loss.py
1
import torch
2
import torch.nn as nn
3
import math
4
5
6
def cart2polar(coord):
    """Convert 2-D Cartesian vectors to normalized polar angles.

    Args:
        coord: tensor of shape (N, 2, ...); channel 0 holds x, channel 1 holds y.

    Returns:
        Tensor of shape (N, ...) with the counter-clockwise angle of each (x, y)
        vector from the +x axis, divided by 2*pi so it lies in [0, 1) (up to
        floating-point rounding). The zero vector maps to 0.
    """
    x = coord[:, 0, ...]
    y = coord[:, 1, ...]

    # atan2 resolves the quadrant itself and never divides by x, so e.g.
    # (0, -1) correctly maps to 3/4. Its gradient is NaN at the origin, so
    # substitute x = 1 there: the angle stays 0 and gradients stay finite.
    origin = (x == 0) & (y == 0)
    x_safe = torch.where(origin, torch.ones_like(x), x)
    theta = torch.atan2(y, x_safe)

    # Map atan2's (-pi, pi] onto [0, 2*pi), then normalize to [0, 1).
    return torch.remainder(theta, 2 * math.pi) / (2 * math.pi)
17
18
class EuclideanAngleLossWithOHEM(nn.Module):
    """Vector-field regression loss with online hard negative mining (OHEM).

    Per pixel, the loss is the squared Euclidean distance between the
    predicted and target 2-D vectors plus the squared difference of their
    normalized polar angles (``cart2polar``). Positive pixels (weight > 0)
    are weighted so that every labelled segment contributes equally; among
    the negative pixels (weight == 0) only the ``npRatio * #positives``
    largest-loss pixels of each sample are kept (capped at the number of
    negatives).

    Args:
        npRatio: maximum number of hard negative pixels per positive pixel.
    """

    def __init__(self, npRatio=3):
        super(EuclideanAngleLossWithOHEM, self).__init__()
        # Maximum number of hard negative pixels kept per positive pixel.
        self.npRatio = npRatio

    def __cal_weight(self, gt):
        """Compute segment-balanced positive weights for a single sample.

        Args:
            gt: label map of shape (1, H, W); values <= 0 are background and
                every value > 0 identifies one segment.

        Returns:
            float32 tensor of shape (H, W). Background pixels get 0; a pixel
            of segment s gets ``posCount / (numSegments * count_s)``, so each
            segment sums to ``posCount / numSegments`` and all positive
            weights together sum to ``posCount``.
        """
        seg = gt[0, ...]
        weight = torch.zeros(seg.shape, dtype=torch.float, device=gt.device)
        posRegion = seg > 0
        posCount = int(posRegion.sum())
        if posCount == 0:
            return weight

        # Choose segment labels by value (> 0) instead of dropping the first
        # unique entry, which would discard a real label whenever the sample
        # has no background pixels.
        _, inverse, counts = torch.unique(
            seg[posRegion], return_inverse=True, return_counts=True)
        segAve = float(posCount) / counts.numel()
        weight[posRegion] = segAve / counts.to(torch.float)[inverse]
        return weight

    @staticmethod
    def _hard_negative_mask(negLoss, numHard):
        """Select the ``numHard[i]`` largest entries of each row of ``negLoss``.

        Args:
            negLoss: (N, P) per-pixel loss, already zeroed outside the
                negative region.
            numHard: (N,) number of entries to keep per row; 0 keeps none.

        Returns:
            Bool tensor (N, P), True at kept entries whose loss is non-zero
            (zero-loss entries add nothing to the loss and are not counted).
        """
        order = torch.argsort(negLoss, dim=1, descending=True)
        keep = torch.zeros_like(negLoss, dtype=torch.bool)
        for i, k in enumerate(numHard.tolist()):
            # A plain ``[:k]`` slice selects nothing when k == 0
            # (``[:-k]`` would select everything in that case).
            keep[i, order[i, :int(k)]] = True
        return keep & (negLoss != 0)

    def forward(self, pred, gt_df, gt, weight=None):
        """Compute the OHEM-weighted Euclidean + angle loss.

        Args:
            pred: predicted vector field, (N, C, H, W); channels 0 and 1 are
                used as the (x, y) components.
            gt_df: target vector field, same shape as ``pred``.
            gt: label map, (N, 1, H, W); only used to derive ``weight`` when
                it is not given (values > 0 are positives).
            weight: optional positive weights, (N, H, W); pixels with weight
                0 are negatives.

        Returns:
            Scalar tensor ``sum(pixelLoss * w) / N / 2 / sum(w)``, where ``w``
            is the positive weight plus 1 at each hard negative pixel; 0 when
            ``w`` is zero everywhere.
        """
        N, _, H, W = pred.shape

        # Squared Euclidean distance between predicted and target vectors.
        distL2 = (pred - gt_df) ** 2
        # Difference of normalized polar angles.
        # NOTE(review): not wrapped around the circle (0.99 vs 0.01 gives
        # 0.98); kept as-is to preserve the loss definition.
        angleDist = cart2polar(gt_df) - cart2polar(pred)
        pixelLoss = distL2[:, 0, ...] + distL2[:, 1, ...] + angleDist ** 2  # (N, H, W)

        if weight is None:
            weight = torch.zeros((N, H, W), device=pred.device)
            for i in range(N):
                weight[i] = self.__cal_weight(gt[i])

        # Number of positive / negative pixels and hard negatives per sample.
        regionPos = (weight > 0).to(torch.float)
        regionNeg = (weight == 0).to(torch.float)
        sumPos = torch.sum(regionPos, dim=(1, 2))  # (N,)
        sumNeg = torch.sum(regionNeg, dim=(1, 2))  # (N,)
        numHardNeg = torch.min(self.npRatio * sumPos, sumNeg)  # (N,)

        # Mining only decides which pixels count; it takes no gradient.
        with torch.no_grad():
            negLoss = torch.flatten(pixelLoss * regionNeg, start_dim=1)  # (N, H*W)
            hardNeg = self._hard_negative_mask(negLoss, numHardNeg)
        weightNeg = hardNeg.view(N, H, W).to(pred.dtype)

        # Keep every term at (N, H, W) so each sample's loss is paired only
        # with its own weights (no cross-sample broadcasting).
        weights = weight + weightNeg
        # Avoid 0/0 when no sample has any positive (hence hard negative) pixel.
        denom = torch.sum(weights).clamp_min(torch.finfo(weights.dtype).tiny)
        # NOTE(review): sum(weights) already spans the batch, so the extra /N
        # scales the loss by 1/N; kept to preserve the existing magnitude.
        return torch.sum(pixelLoss * weights) / N / 2. / denom
104
105
106
107
if __name__ == "__main__":
    import os

    # Respect a GPU chosen by the caller; otherwise default to device 4.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")
    # Fall back to CPU so the smoke test also runs without that GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    criterion = EuclideanAngleLossWithOHEM()

    # Smoke test: use the negated target field as the prediction and print
    # the resulting loss for random inputs.
    for i in range(100):
        gt_df = torch.randn((32, 2, 224, 224), device=device)
        gt = torch.randint(0, 4, (32, 1, 224, 224), device=device)

        loss = criterion(-gt_df, gt_df, gt)
        print("{:6} loss:{}".format(i, loss))