[98e649]: / libs / losses / df_loss.py

Download this file

109 lines (87 with data), 3.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def LossSegDF(net_ret, data, device="cuda"):
    """Return the sum of the segmentation loss and the direction-field loss.

    Args:
        net_ret: Pair ``(net_out, df_out)``. ``net_out`` is used as the
            input of ``F.cross_entropy``; ``df_out`` is the prediction
            given to ``F.mse_loss``.
        data: Triple ``(_, gts, gts_df)``. The first item is ignored,
            ``gts`` holds the segmentation labels and ``gts_df`` the
            direction-field target.
        device: Device the two targets are moved to before the losses are
            computed.

    Returns:
        Tensor holding ``cross_entropy(net_out, gts) + mse(df_out, gts_df)``.
    """
    net_out, df_out = net_ret
    _, gts, gts_df = data

    # cross_entropy takes integer class indices; dimension 1 is squeezed out
    # first (presumably a singleton channel axis -- confirm against the loader).
    seg_target = torch.squeeze(gts, 1).to(device).long()

    # The direction field is a regression target: keep it in the dtype of the
    # prediction. Casting it to an integer type would truncate fractional
    # values toward zero and hand mse_loss mismatched dtypes.
    df_target = gts_df.to(device=device, dtype=df_out.dtype)

    # segmentation loss
    seg_loss = F.cross_entropy(net_out, seg_target)
    # direction field loss
    df_loss = F.mse_loss(df_out, df_target)

    return seg_loss + df_loss
class EuclideanLossWithOHEM(nn.Module):
    """Weighted squared-error loss with hard-example mining on background.

    Foreground pixels (label > 0) are weighted so that every labelled segment
    contributes the same total weight. Among background pixels only the ones
    with the largest squared error are kept, at most ``npRatio`` times the
    number of foreground pixels per image.
    """

    def __init__(self, npRatio=3):
        """Store ``npRatio``, the cap on kept negatives per positive pixel."""
        super(EuclideanLossWithOHEM, self).__init__()
        self.npRatio = npRatio

    def __cal_weight(self, gt):
        """Build the per-pixel foreground weight map of one label map.

        Args:
            gt: Label tensor of shape ``(1, H, W)``; values ``> 0`` are
                segment ids, everything else is treated as background.

        Returns:
            Float tensor of shape ``(H, W)`` on ``gt.device``. Background
            pixels get 0; a pixel of segment ``s`` gets
            ``(posCount / number_of_segments) / pixel_count(s)``.
        """
        _, H, W = gt.shape  # N=1
        label_map = gt[0, ...]
        weight = torch.zeros((H, W), dtype=torch.float, device=gt.device)

        posCount = torch.sum(label_map > 0)
        if posCount == 0:
            return weight

        # Select foreground ids explicitly. Dropping "the smallest label"
        # instead would discard a real segment whenever the map contains no
        # background pixel, and divide by zero if it was the only segment.
        found = torch.unique(gt, sorted=True)
        labels = found[found > 0]

        # Every id returned by unique() occurs at least once, so the number
        # of segments is simply the number of ids.
        segAve = float(posCount) / labels.numel()
        for segi in labels:
            overlap_segi = label_map == segi
            overlapCount_segi = torch.sum(overlap_segi, dtype=torch.float)
            weight[overlap_segi] = segAve / overlapCount_segi
        return weight

    def forward(self, pred, gt_df, gt, weight=None):
        """Compute the loss.

        Args:
            pred: Prediction of shape ``(N, C, H, W)``. Channels 0 and 1 are
                indexed and the weights are stacked twice, so ``C`` must be 2.
            gt_df: Target with the same shape as ``pred``.
            gt: Label maps of shape ``(N, 1, H, W)``; only read when
                ``weight`` is None.
            weight: Optional precomputed foreground weights of shape
                ``(N, H, W)``; pixels with weight 0 are the negatives.

        Returns:
            Scalar tensor: the weighted sum of squared errors divided by
            ``N``, by 2 and by the sum of all weights. If every weight is
            zero the division is 0/0.
        """
        N, _, H, W = pred.shape
        distL2 = (pred - gt_df) ** 2

        if weight is None:
            weight = torch.zeros((N, H, W), device=pred.device)
            for i in range(N):
                weight[i] = self.__cal_weight(gt[i])

        # the amount of positive and negative pixels
        regionPos = (weight > 0).to(torch.float)
        regionNeg = (weight == 0).to(torch.float)
        sumPos = torch.sum(regionPos, dim=(1, 2))  # (N,)
        sumNeg = torch.sum(regionNeg, dim=(1, 2))

        # the amount of hard negative pixels
        sumhardNeg = torch.min(self.npRatio * sumPos, sumNeg).to(torch.int)  # (N,)

        # Per-pixel error on negatives. It is only used to choose a mask, so
        # it is detached from the autograd graph.
        lossNeg = (distL2[:, 0, ...] + distL2[:, 1, ...]).detach() * regionNeg
        lossFlat = torch.flatten(lossNeg, start_dim=1)  # (N, H*W)
        order = torch.argsort(lossFlat, dim=1)  # ascending error

        hardMask = torch.zeros_like(lossFlat, dtype=torch.bool)
        for i in range(N):
            k = int(sumhardNeg[i])
            # NOTE: when k == 0 (for example an image without any positive
            # pixel) every negative is kept rather than none. This is the
            # long-standing behaviour of this loss and it keeps an
            # all-background batch from ending in 0/0.
            hardest = order[i, -k:] if k > 0 else order[i]
            hardMask[i, hardest] = True
        # Pixels whose error is exactly zero never count as hard negatives.
        hardMask &= lossFlat != 0
        hardMask = hardMask.view(lossNeg.shape)

        # weight for positive and negative pixels, replicated on both channels
        weightPos = torch.stack([weight, weight], dim=1)
        weightNeg = torch.stack([hardMask, hardMask], dim=1).to(pred.dtype)
        weightAll = weightPos + weightNeg

        # total loss
        total_loss = torch.sum(distL2 * weightAll) / N / 2. / torch.sum(weightAll)
        return total_loss
if __name__ == "__main__":
    # Smoke test: evaluate the OHEM loss on random CUDA tensors 100 times.
    import os

    # Restrict this process to the first GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    field_shape = (32, 2, 224, 224)
    label_shape = (32, 1, 224, 224)

    criterion = EuclideanLossWithOHEM()
    for step in range(100):
        # `pred` is sampled but never handed to the criterion; the call below
        # uses a scaled copy of the target as the prediction instead.
        pred = torch.randn(field_shape).cuda()
        gt_df = torch.randn(field_shape).cuda()
        gt = torch.randint(0, 4, label_shape).cuda()

        loss = criterion(100 * gt_df, gt_df, gt)
        print("{:6} loss:{}".format(step, loss))