eval/loss.py
|
import argparse

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
|
def sim(z_i, z_j):
    """Normalized dot product (cosine similarity) between two z-vectors."""
    norm_dot_product = torch.dot(z_i, z_j) / (
        torch.linalg.norm(z_i) * torch.linalg.norm(z_j)
    )
    return norm_dot_product
|
|
def sim_positive_pairs(out_left, out_right):
    """Normalized dot product between positive pairs.

    Inputs:
    - out_left: NxD tensor; output of the projection head g(), left branch of the SimCLR model.
    - out_right: NxD tensor; output of the projection head g(), right branch of the SimCLR model.
      Each row is a z-vector for an augmented sample in the batch.
      The same row in out_left and out_right forms a positive pair.

    Returns:
    - An N x 1 tensor; each row k is the normalized dot product between out_left[k] and out_right[k].
    """
    pos_pairs = torch.sum(out_left * out_right, dim=1) / (
        torch.linalg.norm(out_left, dim=1) * torch.linalg.norm(out_right, dim=1)
    )
    return pos_pairs.unsqueeze(1)
|
|
def compute_sim_matrix(out):
    """Compute a 2N x 2N matrix of normalized dot products between all pairs of augmented examples in a batch.

    Inputs:
    - out: 2N x D tensor; each row is the z-vector (output of the projection head) of a single augmented example.
      There are a total of 2N augmented examples in the batch.

    Returns:
    - sim_matrix: 2N x 2N tensor; element (i, j) is the normalized dot product between out[i] and out[j].
    """
    sim_matrix = torch.mm(out, out.T)
    # Normalize row i by ||out[i]|| and column j by ||out[j]||.
    sim_matrix /= torch.linalg.norm(out, dim=1).unsqueeze(1)
    sim_matrix /= torch.linalg.norm(out, dim=1).unsqueeze(0)
    return sim_matrix
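
# For reference, simclr_loss_vectorized below implements the NT-Xent loss from
# SimCLR (Chen et al., 2020): for a positive pair (i, j) among the 2N augmented
# samples,
#     l(i, j) = -log( exp(sim(z_i, z_j) / tau)
#                     / sum_{k != i} exp(sim(z_i, z_k) / tau) )
# and the batch loss averages l(i, j) + l(j, i) over all N positive pairs.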
|
|
def simclr_loss_vectorized(out_left, out_right, tau):
    """Compute the SimCLR contrastive loss for a batch of N positive pairs.

    Inputs:
    - out_left: NxD tensor; projection-head outputs for the left branch.
    - out_right: NxD tensor; projection-head outputs for the right branch.
      Row k of out_left and out_right forms a positive pair.
    - tau: temperature used to scale the similarities.

    Returns:
    - loss: scalar tensor; the loss averaged over all 2N augmented samples.
    """
    N = out_left.shape[0]
    device = out_left.device

    # Concatenate out_left and out_right into a 2N x D tensor.
    out = torch.cat([out_left, out_right], dim=0)  # [2*N, D]

    # Similarity matrix between all pairs of augmented examples in the batch.
    sim_matrix = compute_sim_matrix(out)  # [2*N, 2*N]

    # Denominator: e^{sim / tau}, summed over all k != i for every sample i.
    exponential = torch.exp(sim_matrix / tau)  # [2*N, 2*N]

    # Binary mask that removes the diagonal terms (k = i).
    mask = ~torch.eye(2 * N, dtype=torch.bool, device=device)
    exponential = exponential.masked_select(mask).view(2 * N, -1)  # [2*N, 2*N - 1]
    denom = torch.sum(exponential, dim=1).unsqueeze(1)  # [2*N, 1]

    # Numerator: e^{sim / tau} for each positive pair.
    sim_pos = sim_positive_pairs(out_left, out_right)  # [N, 1]
    numerator = torch.exp(sim_pos / tau)

    # Average -log(numerator / denominator) over both views of every positive
    # pair (2N terms in total).
    loss = torch.sum(-torch.log(numerator / denom[:N])
                     - torch.log(numerator / denom[N:])) / (2 * N)

    return loss
|
|
class CombinedLoss(nn.Module):
    """Weighted Dice + soft-BCE segmentation loss with an optional auxiliary loss term."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, truth, aux=None, aux_features=None):
        if pred is not None:
            dice_loss = smp.losses.DiceLoss(mode='multilabel', from_logits=True)
            bce_loss = smp.losses.SoftBCEWithLogitsLoss()
            loss = 0.4 * dice_loss(pred, truth) + 0.6 * bce_loss(pred, truth)

        if aux == "simclr":
            N, D = aux_features.shape
            # Drop the last feature vector if the batch size is odd so that the
            # features split evenly into (left, right) positive pairs.
            if (N % 2) == 1:
                aux_features = aux_features[:-1, :]
                N, D = aux_features.shape

            aux_features = aux_features.reshape(N // 2, 2, D)
            simclr_loss = simclr_loss_vectorized(aux_features[:, 0, :],
                                                 aux_features[:, 1, :], tau=0.1)

            if pred is None:
                return simclr_loss
            else:
                loss = 0.9 * loss + 0.1 * simclr_loss
                return loss, simclr_loss

        elif aux == "reg":
            # Add reg loss code here
            pass

        else:
            return loss
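
# Illustrative usage of CombinedLoss (an assumed call pattern; the actual call
# sites live in the training code):
#     criterion = CombinedLoss()
#     seg_loss = criterion(pred, truth)                               # Dice + BCE only
#     seg_loss, aux_loss = criterion(pred, truth, aux="simclr",
#                                    aux_features=features)           # + SimCLR term
#     aux_only = criterion(None, None, aux="simclr", aux_features=features)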
|
|
def get_loss_fn(loss_args):
    """Return the loss function selected by the "loss_fn" entry of loss_args.

    loss_args may be a dict or an argparse.Namespace.
    """
    loss_args_ = loss_args
    if isinstance(loss_args, argparse.Namespace):
        loss_args_ = vars(loss_args)
    loss_fn = loss_args_.get("loss_fn")

    if loss_fn == "BCE":
        return torch.nn.BCEWithLogitsLoss()
    elif loss_fn == "CE":
        return torch.nn.CrossEntropyLoss()
    elif loss_fn == 'DBE':
        return smp.losses.DiceLoss(mode='binary', from_logits=True)
    elif loss_fn == 'DLE':
        return smp.losses.DiceLoss(mode='multilabel', from_logits=True)
    elif loss_fn == 'Combined':
        return CombinedLoss()
    else:
        raise ValueError(f"loss_fn {loss_fn} not supported.")
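

# Minimal smoke test (illustrative only; the batch size, feature dimension, and
# segmentation shapes below are assumptions, not values used elsewhere in the
# project).
if __name__ == "__main__":
    torch.manual_seed(0)

    # SimCLR loss on random projection-head outputs.
    z_left = torch.randn(8, 128)
    z_right = torch.randn(8, 128)
    print("simclr loss:", simclr_loss_vectorized(z_left, z_right, tau=0.1).item())

    # Combined segmentation + auxiliary SimCLR loss on random logits and masks.
    pred = torch.randn(4, 3, 64, 64)                      # multilabel logits
    truth = torch.randint(0, 2, (4, 3, 64, 64)).float()   # binary masks per class
    features = torch.randn(8, 128)                        # rows pair up as (left, right)
    criterion = get_loss_fn({"loss_fn": "Combined"})
    loss, simclr_term = criterion(pred, truth, aux="simclr", aux_features=features)
    print("combined loss:", loss.item(), "simclr term:", simclr_term.item())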