--- a +++ b/losses/consistency_losses.py @@ -0,0 +1,75 @@ +import matplotlib.pyplot as plt +import segmentation_models_pytorch.utils.losses as vanilla_losses +import torch +import torch.nn as nn +import numpy as np + + +class ConsistencyLoss(nn.Module): + """ + + """ + + def __init__(self, adaptive=True): + super(ConsistencyLoss, self).__init__() + self.epsilon = 1e-5 + self.adaptive = adaptive + self.jaccard = vanilla_losses.JaccardLoss() + + def forward(self, new_mask, old_mask, new_seg, old_seg, iou_weight=None): + def difference(mask1, mask2): + return mask1 * (1 - mask2) + mask2 * (1 - mask1) + + vanilla_jaccard = vanilla_losses.JaccardLoss()(old_seg, old_mask) + + perturbation_loss = torch.sum( + difference( + difference(new_mask, old_mask), + difference(new_seg, old_seg)) + ) / torch.sum(torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + self.epsilon) + + # ~bd + ~ac + b~d + a~c + # perturbation_loss = (1 - new_mask) * new_seg + (1 - old_mask) * old_seg + new_mask * ( + # 1 - new_seg) + old_mask * (1 - old_seg) + if self.adaptive: + return (1 - iou_weight) * vanilla_jaccard + iou_weight * perturbation_loss + return 0.5 * vanilla_jaccard + 0.5 * perturbation_loss + # return perturbation_loss + # return self.jaccard(old_seg, old_mask) + self.jaccard(new_seg, new_mask) + + +class NakedConsistencyLoss(nn.Module): + """ + + """ + + def __init__(self): + super(NakedConsistencyLoss, self).__init__() + self.epsilon = 1e-5 + + def forward(self, new_mask, old_mask, new_seg, old_seg): + def difference(mask1, mask2): + return mask1 * (1 - mask2) + mask2 * (1 - mask1) + + perturbation_loss = torch.sum( + difference( + difference(new_mask, old_mask), + difference(new_seg, old_seg)) + ) / torch.sum(torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + self.epsilon) # normalizing factor + # perturbation_loss = torch.sum(perturbation_loss) + return perturbation_loss + + +class StrictConsistencyLoss(nn.Module): + def __init__(self, adaptive=True): + super(StrictConsistencyLoss, self).__init__() + self.epsilon = 1e-5 + self.adaptive = adaptive + self.jaccard = vanilla_losses.JaccardLoss() + + def forward(self, new_mask, old_mask, new_seg, old_seg, iou_weight=None): + if iou_weight is not None: + return iou_weight * self.jaccard(new_seg, new_mask) + (1 - iou_weight) * self.jaccard(old_seg, old_mask) + return 0.5 * self.jaccard(new_seg, new_mask) + 0.5 * self.jaccard(old_seg, old_mask) + # return perturbation_loss + # return self.jaccard(old_seg, old_mask) + self.jaccard(new_seg, new_mask)