Diff of /eval/loss.py [000000] .. [139527]

import argparse

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp


def sim(z_i, z_j):
    """Cosine similarity (normalized dot product) between two 1-D vectors."""
    return torch.dot(z_i, z_j) / (
        torch.linalg.norm(z_i) * torch.linalg.norm(z_j)
    )
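

# Illustrative sanity check for sim() (a sketch, not called by this module):
# identical vectors should have cosine similarity 1, orthogonal vectors 0.
def _demo_sim():
    z = torch.tensor([1.0, 2.0, 3.0])
    assert torch.isclose(sim(z, z), torch.tensor(1.0))
    e1, e2 = torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])
    assert torch.isclose(sim(e1, e2), torch.tensor(0.0))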


def sim_positive_pairs(out_left, out_right):
    """Normalized dot product between positive pairs.

    Inputs:
    - out_left: NxD tensor; output of the projection head g(), left branch of the SimCLR model.
    - out_right: NxD tensor; output of the projection head g(), right branch of the SimCLR model.
    Each row is the z-vector for an augmented sample in the batch; the same row
    in out_left and out_right forms a positive pair.

    Returns:
    - Nx1 tensor; row k is the normalized dot product between out_left[k] and out_right[k].
    """
    pos_pairs = torch.sum(out_left * out_right, dim=1) / (
        torch.linalg.norm(out_left, dim=1) * torch.linalg.norm(out_right, dim=1)
    )
    return pos_pairs.unsqueeze(1)
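

# Shape/value sketch for sim_positive_pairs() (N=4, D=8 chosen only for
# illustration): identical branches give an all-ones [N, 1] result.
def _demo_sim_positive_pairs():
    z = torch.randn(4, 8)
    pos = sim_positive_pairs(z, z)
    assert pos.shape == (4, 1)
    assert torch.allclose(pos, torch.ones(4, 1))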


def compute_sim_matrix(out):
    """Compute a 2N x 2N matrix of normalized dot products between all pairs of augmented examples in a batch.

    Inputs:
    - out: 2N x D tensor; each row is the z-vector (output of the projection head)
      of a single augmented example. There are a total of 2N augmented examples
      in the batch.

    Returns:
    - sim_matrix: 2N x 2N tensor; element (i, j) is the normalized dot product
      between out[i] and out[j].
    """
    norms = torch.linalg.norm(out, dim=1, keepdim=True)  # [2N, 1]
    sim_matrix = torch.mm(out, out.T)
    sim_matrix = sim_matrix / (norms * norms.T)
    return sim_matrix
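

# Consistency sketch for compute_sim_matrix() (shapes are illustrative): the
# diagonal is ~1 (each vector against itself) and entry (i, j) matches sim().
def _demo_compute_sim_matrix():
    out = torch.randn(6, 8)
    S = compute_sim_matrix(out)
    assert S.shape == (6, 6)
    assert torch.allclose(torch.diagonal(S), torch.ones(6), atol=1e-6)
    assert torch.isclose(S[0, 1], sim(out[0], out[1]), atol=1e-6)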


def simclr_loss_vectorized(out_left, out_right, tau):
    """Compute the SimCLR (NT-Xent) loss over a batch of N positive pairs.

    Inputs:
    - out_left, out_right: NxD tensors of projection-head outputs; row k of
      out_left and row k of out_right are the two augmented views of sample k.
    - tau: temperature parameter.

    Returns:
    - Scalar tensor; the loss averaged over all 2N augmented examples.
    """
    device = out_left.device  # keep all intermediates on the inputs' device
    N = out_left.shape[0]

    # Concatenate out_left and out_right into a 2N x D tensor.
    out = torch.cat([out_left, out_right], dim=0)  # [2N, D]

    # Similarity matrix between all pairs of augmented examples in the batch.
    sim_matrix = compute_sim_matrix(out)  # [2N, 2N]

    # Step 1: compute e^{sim / tau} for all pairs; shape [2N, 2N].
    exponential = torch.exp(sim_matrix / tau)

    # Binary mask that removes the k = i (self-similarity) terms.
    mask = ~torch.eye(2 * N, dtype=torch.bool, device=device)

    # Apply the mask, leaving 2N - 1 terms per row.
    exponential = exponential.masked_select(mask).view(2 * N, -1)

    # Denominator for each augmented sample; shape [2N, 1].
    denom = torch.sum(exponential, dim=1, keepdim=True)

    # Step 2: similarity between positive pairs; shape [N, 1].
    sim_pos = sim_positive_pairs(out_left, out_right)
    numerator = torch.exp(sim_pos / tau)

    # Step 3: per-sample -log(numerator / denominator), summed over both
    # branches and averaged over all 2N augmented examples.
    loss = torch.sum(
        -torch.log(numerator / denom[:N]) - torch.log(numerator / denom[N:])
    ) / (2 * N)
    return loss
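

# Cross-check sketch (synthetic shapes, illustration only): the vectorized loss
# should match a naive per-sample NT-Xent computation,
#   l(i, j) = -log( exp(sim(i, j)/tau) / sum_{k != i} exp(sim(i, k)/tau) ),
# averaged over all 2N augmented examples.
def _demo_simclr_loss(N=4, D=16, tau=0.1):
    torch.manual_seed(0)
    out_left, out_right = torch.randn(N, D), torch.randn(N, D)
    out = torch.cat([out_left, out_right], dim=0)
    total = torch.tensor(0.0)
    for i in range(2 * N):
        j = (i + N) % (2 * N)  # index of i's positive pair
        denom = sum(torch.exp(sim(out[i], out[k]) / tau)
                    for k in range(2 * N) if k != i)
        total = total - torch.log(torch.exp(sim(out[i], out[j]) / tau) / denom)
    naive = total / (2 * N)
    assert torch.isclose(naive, simclr_loss_vectorized(out_left, out_right, tau), atol=1e-5)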


class CombinedLoss(nn.Module):
    """Dice + BCE segmentation loss, optionally blended with an auxiliary SimCLR loss."""

    def __init__(self):
        super().__init__()
        self.dice_loss = smp.losses.DiceLoss(mode="multilabel", from_logits=True)
        self.bce_loss = smp.losses.SoftBCEWithLogitsLoss()

    def forward(self, pred, truth, aux=None, aux_features=None):
        loss = None
        if pred is not None:
            loss = 0.4 * self.dice_loss(pred, truth) + 0.6 * self.bce_loss(pred, truth)

        if aux == "simclr":
            N, D = aux_features.shape
            # Drop the last row if the batch has an odd number of feature vectors.
            if N % 2 == 1:
                aux_features = aux_features[:-1, :]
                N, D = aux_features.shape

            # Consecutive rows are assumed to be the two augmented views of one sample.
            aux_features = aux_features.reshape(N // 2, 2, D)
            simclr_loss = simclr_loss_vectorized(
                aux_features[:, 0, :], aux_features[:, 1, :], tau=0.1
            )

            if loss is None:
                return simclr_loss
            return 0.9 * loss + 0.1 * simclr_loss, simclr_loss

        elif aux == "reg":
            # TODO: add the regularization loss here.
            raise NotImplementedError("'reg' auxiliary loss is not implemented yet.")

        return loss
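

# Usage sketch for CombinedLoss (all shapes are assumptions for illustration;
# 'multilabel' Dice expects [B, C, H, W] logits with {0, 1} targets):
def _demo_combined_loss():
    criterion = CombinedLoss()
    pred = torch.randn(2, 3, 32, 32)                     # segmentation logits
    truth = torch.randint(0, 2, (2, 3, 32, 32)).float()  # binary masks
    feats = torch.randn(4, 128)                          # two views per sample, in consecutive rows
    seg_only = criterion(pred, truth)                    # Dice/BCE blend
    total, simclr = criterion(pred, truth, aux="simclr", aux_features=feats)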


def get_loss_fn(loss_args):
    """Return the loss module named by loss_args (a dict or argparse.Namespace)."""
    loss_args_ = loss_args
    if isinstance(loss_args, argparse.Namespace):
        loss_args_ = vars(loss_args)
    loss_fn = loss_args_.get("loss_fn")

    if loss_fn == "BCE":
        return torch.nn.BCEWithLogitsLoss()
    elif loss_fn == "CE":
        return torch.nn.CrossEntropyLoss()
    elif loss_fn == "DBE":
        return smp.losses.DiceLoss(mode="binary", from_logits=True)
    elif loss_fn == "DLE":
        return smp.losses.DiceLoss(mode="multilabel", from_logits=True)
    elif loss_fn == "Combined":
        return CombinedLoss()
    else:
        raise ValueError(f"loss_fn {loss_fn} not supported.")
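

# Usage sketch for get_loss_fn (the argparse flag is an assumption about how
# this module is driven elsewhere, not something defined in this file):
def _demo_get_loss_fn():
    args = argparse.Namespace(loss_fn="Combined")
    assert isinstance(get_loss_fn(args), CombinedLoss)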