a | b/loss/Contrastive.py | ||
---|---|---|---|
1 | import torch.nn.functional as F |
||
2 | from torch import nn |
||
3 | import torch |
||
4 | |||
5 | def ContrastiveLoss(output1, output2): |
||
6 | euclidean_distance = F.pairwise_distance(output1, output2) |
||
7 | loss_contrastive = torch.mean(torch.pow(euclidean_distance, 2)) |
||
8 | return loss_contrastive |