a b/loss/IoU.py
1
import torch
2
from .utils import *
3
import numpy as np
4
5
6
def IoU_loss(input, target, threshold=0.5, eps=1e-6):
    """
    2D IoU (Jaccard) loss between a predicted map and a target mask.

    Both tensors are binarized before the overlap is measured. Because
    ``input > threshold`` is a comparison op, the returned value carries no
    gradient back to ``input``.
    NOTE(review): confirm this is meant as an evaluation metric rather than a
    training loss.

    :param input: predicted tensor. It is reduced over dims 2 and 3, so it
        must be at least 4-D (presumably (batch, channel, H, W); TODO confirm).
    :param target: target tensor, assumed to have the same shape as ``input``.
    :param threshold: prediction values strictly greater than this count as
        foreground.
    :param eps: smoothing term added to both intersection and union, so that
        an empty prediction against an empty target yields IoU 1 rather than
        NaN.
    :return: scalar loss value, ``1 - sum(IoU) / batch_size``.
    """
    # Binarize the prediction with the caller's threshold. Previously the
    # threshold was hard-coded to 0.5, which silently ignored this argument.
    pred_mask = to_float_and_cuda(input > threshold)

    # Target foreground is every element equal to the maximum of the WHOLE
    # target tensor, across all images in the batch.
    # NOTE(review): if the target has no foreground at all, max() equals the
    # background value and every pixel becomes foreground. Confirm that callers
    # never pass all-background batches.
    true_mask = to_float_and_cuda(target == torch.max(target))

    # Masks are 0/1, so x * x == x and the squared sums reduce to plain sums.
    # dim=(2, 3) is equivalent to two successive sum(dim=2) calls.
    intersection = torch.sum(pred_mask * true_mask, dim=(2, 3))
    pred_area = torch.sum(pred_mask, dim=(2, 3))
    true_area = torch.sum(true_mask, dim=(2, 3))
    union = pred_area + true_area - intersection

    # eps must sit inside the ratio to guard against 0/0 when both masks are
    # empty. Adding it after the division did not prevent NaN.
    iou = (intersection + eps) / (union + eps)

    # Divide by the batch size (dim 0).
    # NOTE(review): with more than one channel this also sums IoU over
    # channels, so the loss can go below 0. Confirm the intended reduction.
    return 1 - torch.sum(iou) / iou.size(0)