|
a |
|
b/losses/consistency_losses.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import segmentation_models_pytorch.utils.losses as vanilla_losses |
|
|
3 |
import torch |
|
|
4 |
import torch.nn as nn |
|
|
5 |
import numpy as np |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class ConsistencyLoss(nn.Module):
    """Jaccard loss on the old frame blended with a perturbation-consistency term.

    The consistency term compares how the ground-truth masks changed
    (``new_mask`` vs. ``old_mask``) against how the predictions changed
    (``new_seg`` vs. ``old_seg``): the two change maps should agree, and any
    mismatch is penalized, normalized by the soft union of all four maps.

    Args:
        adaptive: when True, blend the two terms with the caller-supplied
            ``iou_weight``; otherwise (or when no weight is given) use a
            fixed 0.5/0.5 blend.
    """

    def __init__(self, adaptive=True):
        super(ConsistencyLoss, self).__init__()
        self.epsilon = 1e-5  # keeps the normalizing denominator strictly positive
        self.adaptive = adaptive
        self.jaccard = vanilla_losses.JaccardLoss()

    def forward(self, new_mask, old_mask, new_seg, old_seg, iou_weight=None):
        """Return the blended loss (scalar tensor).

        Args:
            new_mask, old_mask: ground-truth masks for the perturbed / original input.
            new_seg, old_seg: predicted segmentations for the perturbed / original input.
            iou_weight: optional blend weight in [0, 1] for the consistency term
                (used only when ``self.adaptive`` is True).
        """
        def difference(mask1, mask2):
            # Soft symmetric difference — reduces to XOR for binary masks.
            return mask1 * (1 - mask2) + mask2 * (1 - mask1)

        # Plain Jaccard loss on the old (unperturbed) pair. Reuse the module's
        # stored instance instead of constructing a new JaccardLoss per call.
        vanilla_jaccard = self.jaccard(old_seg, old_mask)

        # Mismatch between the change in masks and the change in predictions,
        # normalized by the (clamped) soft union of all four maps.
        perturbation_loss = torch.sum(
            difference(
                difference(new_mask, old_mask),
                difference(new_seg, old_seg))
        ) / torch.sum(torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + self.epsilon)

        if self.adaptive and iou_weight is not None:
            return (1 - iou_weight) * vanilla_jaccard + iou_weight * perturbation_loss
        # Non-adaptive mode, or adaptive mode with no weight supplied: equal blend.
        # (Previously adaptive=True with the default iou_weight=None raised a
        # TypeError on the arithmetic above.)
        return 0.5 * vanilla_jaccard + 0.5 * perturbation_loss
|
|
39 |
|
|
|
40 |
|
|
|
41 |
class NakedConsistencyLoss(nn.Module):
    """The perturbation-consistency term alone, with no Jaccard component.

    Measures the mismatch between how the ground-truth masks changed and how
    the predictions changed, normalized by the soft union of all four maps.
    """

    def __init__(self):
        super(NakedConsistencyLoss, self).__init__()
        self.epsilon = 1e-5  # keeps the normalizer strictly positive

    def forward(self, new_mask, old_mask, new_seg, old_seg):
        """Return the normalized consistency mismatch (scalar tensor)."""
        def sym_diff(a, b):
            # Soft symmetric difference — reduces to XOR for binary maps.
            return a * (1 - b) + b * (1 - a)

        mask_change = sym_diff(new_mask, old_mask)
        seg_change = sym_diff(new_seg, old_seg)
        mismatch = torch.sum(sym_diff(mask_change, seg_change))
        # Normalizing factor: soft union of all four maps, epsilon-padded per element.
        support = torch.sum(
            torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + self.epsilon)
        return mismatch / support
|
|
61 |
|
|
|
62 |
|
|
|
63 |
class StrictConsistencyLoss(nn.Module):
    """Weighted sum of Jaccard losses on the new and old (seg, mask) pairs.

    Args:
        adaptive: stored for interface parity with ConsistencyLoss; the blend
            is actually driven by whether ``iou_weight`` is supplied to forward.
    """

    def __init__(self, adaptive=True):
        super(StrictConsistencyLoss, self).__init__()
        self.epsilon = 1e-5
        self.adaptive = adaptive
        self.jaccard = vanilla_losses.JaccardLoss()

    def forward(self, new_mask, old_mask, new_seg, old_seg, iou_weight=None):
        """Return iou_weight-blended Jaccard loss over both frame pairs.

        Falls back to an equal 0.5/0.5 blend when no ``iou_weight`` is given.
        """
        new_term = self.jaccard(new_seg, new_mask)
        old_term = self.jaccard(old_seg, old_mask)
        if iou_weight is None:
            return 0.5 * new_term + 0.5 * old_term
        return iou_weight * new_term + (1 - iou_weight) * old_term