Switch to unified view

a b/losses/consistency_losses.py
1
import matplotlib.pyplot as plt
2
import segmentation_models_pytorch.utils.losses as vanilla_losses
3
import torch
4
import torch.nn as nn
5
import numpy as np
6
7
8
class ConsistencyLoss(nn.Module):
9
    """
10
11
    """
12
13
    def __init__(self, adaptive=True):
14
        super(ConsistencyLoss, self).__init__()
15
        self.epsilon = 1e-5
16
        self.adaptive = adaptive
17
        self.jaccard = vanilla_losses.JaccardLoss()
18
19
    def forward(self, new_mask, old_mask, new_seg, old_seg, iou_weight=None):
20
        def difference(mask1, mask2):
21
            return mask1 * (1 - mask2) + mask2 * (1 - mask1)
22
23
        vanilla_jaccard = vanilla_losses.JaccardLoss()(old_seg, old_mask)
24
25
        perturbation_loss = torch.sum(
26
            difference(
27
                difference(new_mask, old_mask),
28
                difference(new_seg, old_seg))
29
        ) / torch.sum(torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + self.epsilon)
30
31
        # ~bd + ~ac + b~d + a~c
32
        # perturbation_loss = (1 - new_mask) * new_seg + (1 - old_mask) * old_seg + new_mask * (
33
        #         1 - new_seg) + old_mask * (1 - old_seg)
34
        if self.adaptive:
35
            return (1 - iou_weight) * vanilla_jaccard + iou_weight * perturbation_loss
36
        return 0.5 * vanilla_jaccard + 0.5 * perturbation_loss
37
        # return perturbation_loss
38
        # return self.jaccard(old_seg, old_mask) + self.jaccard(new_seg, new_mask)
39
40
41
class NakedConsistencyLoss(nn.Module):
42
    """
43
44
    """
45
46
    def __init__(self):
47
        super(NakedConsistencyLoss, self).__init__()
48
        self.epsilon = 1e-5
49
50
    def forward(self, new_mask, old_mask, new_seg, old_seg):
51
        def difference(mask1, mask2):
52
            return mask1 * (1 - mask2) + mask2 * (1 - mask1)
53
54
        perturbation_loss = torch.sum(
55
            difference(
56
                difference(new_mask, old_mask),
57
                difference(new_seg, old_seg))
58
        ) / torch.sum(torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + self.epsilon)  # normalizing factor
59
        # perturbation_loss = torch.sum(perturbation_loss)
60
        return perturbation_loss
61
62
63
class StrictConsistencyLoss(nn.Module):
64
    def __init__(self, adaptive=True):
65
        super(StrictConsistencyLoss, self).__init__()
66
        self.epsilon = 1e-5
67
        self.adaptive = adaptive
68
        self.jaccard = vanilla_losses.JaccardLoss()
69
70
    def forward(self, new_mask, old_mask, new_seg, old_seg, iou_weight=None):
71
        if iou_weight is not None:
72
            return iou_weight * self.jaccard(new_seg, new_mask) + (1 - iou_weight) * self.jaccard(old_seg, old_mask)
73
        return 0.5 * self.jaccard(new_seg, new_mask) + 0.5 * self.jaccard(old_seg, old_mask)
74
        # return perturbation_loss
75
        # return self.jaccard(old_seg, old_mask) + self.jaccard(new_seg, new_mask)