[cf6a9e]: / loss / IoU.py

Download this file

34 lines (26 with data), 797 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from .utils import *
import numpy as np
def IoU_loss(input, target, threshold=0.5, eps=1e-6):
    """Return ``1 - IoU`` between a thresholded prediction and a binarized target.

    The prediction is hard-thresholded (``input > threshold``). That comparison
    yields a boolean tensor, so no gradient flows back to ``input`` through
    this function. It therefore behaves like a metric rather than a trainable
    loss.

    :param input: predicted tensor. Elements strictly greater than
        ``threshold`` count as foreground. Presumably shaped (N, C, H, W),
        since dims 2 and 3 are reduced and dim 0 is the batch — TODO confirm.
    :param target: ground-truth tensor of the same shape. Elements equal to the
        largest value in the whole batch, and greater than 0, count as
        foreground.
    :param threshold: cut-off applied to ``input``. The default 0.5 matches the
        value that was previously hard-coded.
    :param eps: smoothing added to both the intersection and the union. It
        makes an empty prediction against an empty target give IoU 1 instead
        of NaN (0 / 0).
    :return: scalar tensor ``1 - sum(IoU) / batch_size``.
    """
    pred_mask = input > threshold
    # The largest label value in the batch is treated as foreground.
    # ``target > 0`` guards the all-background batch: there the max is 0, and
    # ``target == max`` alone would mark every pixel as foreground.
    target_mask = (target == torch.max(target)) & (target > 0)

    pred_mask = to_float_and_cuda(pred_mask)
    target_mask = to_float_and_cuda(target_mask)

    # Reducing dims 2 and 3 together is equivalent to the original pair of
    # successive ``sum(dim=2)`` calls.
    spatial_dims = (2, 3)
    intersection = torch.sum(pred_mask * target_mask, dim=spatial_dims)
    # Assumes to_float_and_cuda maps True/False to 1.0/0.0, so sum(x) equals
    # sum(x * x) on these masks.
    union = (
        torch.sum(pred_mask, dim=spatial_dims)
        + torch.sum(target_mask, dim=spatial_dims)
        - intersection
    )
    iou = (intersection + eps) / (union + eps)
    # IoUs are summed over every (sample, channel) pair but divided by the
    # batch size only. With C > 1 channels the result can go below 0.
    # NOTE(review): confirm that single-channel inputs are intended.
    return 1 - torch.sum(iou) / iou.size(0)