--- /dev/null
+++ b/shepherd/utils/loss_utils.py
@@ -0,0 +1,270 @@
+from pytorch_metric_learning.distances import LpDistance
+from pytorch_metric_learning.utils import common_functions as c_f
+from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
+from pytorch_metric_learning.losses import BaseMetricLossFunction
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+def unique(x, dim=None):
+    """Unique elements of x and the index of the first occurrence of each.
+    https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810
+    e.g.
+    unique(tensor([
+        [1, 2, 3],
+        [1, 2, 4],
+        [1, 2, 3],
+        [1, 2, 5]
+    ]), dim=0)
+    => (tensor([[1, 2, 3],
+                [1, 2, 4],
+                [1, 2, 5]]),
+        tensor([0, 1, 3]))
+    """
+    unique, inverse = torch.unique(
+        x, sorted=True, return_inverse=True, dim=dim)
+    perm = torch.arange(inverse.size(0), dtype=inverse.dtype,
+                        device=inverse.device)
+    # flip so that, in the scatter below, the index of the *first*
+    # occurrence of each unique element is the one that survives
+    inverse, perm = inverse.flip([0]), perm.flip([0])
+    return unique, inverse.new_empty(unique.size(dim)).scatter_(0, inverse, perm)
+
+
+def _construct_labels(candidate_embeddings, candidate_node_idx, correct_node_idx, mask):
+    '''
+    Format the batch for input into the metric-learning loss function.
+    '''
+    batch, n_candidates, embed_dim = candidate_embeddings.shape
+
+    # reshape the mask so it broadcasts against the flattened embeddings
+    mask_reshaped = mask.reshape(batch * n_candidates, -1)
+    expanded_mask = mask_reshaped.expand(-1, embed_dim)
+
+    # flatten the candidate node idx and candidate embeddings
+    candidate_node_idx_flattened = candidate_node_idx.view(batch * n_candidates, -1)
+    candidate_embeddings_flattened = candidate_embeddings.view(batch * n_candidates, -1)
+    candidate_embeddings_flattened = candidate_embeddings_flattened * expanded_mask
+
+    # get unique node idx & corresponding embeddings
+    candidate_node_idx_flattened_unique, unique_ind = unique(candidate_node_idx_flattened, dim=0)
+    candidate_embeddings_flattened_unique = candidate_embeddings_flattened[unique_ind, :]
+
+    # remove padding (node idx 0 sorts first when present)
+    if candidate_node_idx_flattened_unique[0] == 0:
+        candidate_embeddings_flattened_unique = candidate_embeddings_flattened_unique[1:, :]
+        candidate_node_idx_flattened_unique = candidate_node_idx_flattened_unique[1:, :]
+
+    # create a one-hot encoding of the correct gene/disease within the unique candidates in the batch
+    label_idx = torch.where(candidate_node_idx_flattened_unique.unsqueeze(1) == correct_node_idx.unsqueeze(0), 1, 0)
+    label_idx = label_idx.sum(dim=-1).T
+
+    return candidate_node_idx_flattened_unique, candidate_embeddings_flattened_unique, label_idx
+
+
+def _construct_disease_labels(disease_embedding, batch_disease_nid):
+    if len(disease_embedding.shape) == 3:
+        batch, n_candidates, embed_dim = disease_embedding.shape
+        batch_disease_nid_reshaped = batch_disease_nid.view(batch * n_candidates, -1)
+        disease_embedding_reshaped = disease_embedding.view(batch * n_candidates, -1)
+    else:
+        batch_disease_nid_reshaped = batch_disease_nid
+        disease_embedding_reshaped = disease_embedding
+
+    # get unique diseases & corresponding embeddings in the batch
+    batch_disease_nid_unique, unique_ind = unique(batch_disease_nid_reshaped, dim=0)
+    disease_embeddings_unique = disease_embedding_reshaped[unique_ind, :]
+
+    # remove padding (node idx 0 sorts first when present)
+    if batch_disease_nid_unique[0] == 0:
+        disease_embeddings_unique = disease_embeddings_unique[1:, :]
+        batch_disease_nid_unique = batch_disease_nid_unique[1:, :]
+
+    # create a one-hot encoding of the correct disease within the unique diseases in the batch
+    label_idx = torch.where(batch_disease_nid_unique.T == batch_disease_nid_reshaped, 1, 0)
+    if len(disease_embedding.shape) == 3:
+        # need to reshape the label_idx back to (batch, n_unique_diseases)
+        batch, n_candidates, embed_dim = disease_embedding.shape
+        label_idx = label_idx.view(batch, n_candidates, -1)
+        label_idx = torch.sum(label_idx, dim=1)
+
+    return disease_embeddings_unique, label_idx
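+
+# Illustrative sketch only (toy values, not from the training pipeline):
+# the contract of `_construct_labels`, assuming node idx 0 is the padding index.
+#
+# >>> cand_idx = torch.tensor([[[3], [7], [0]], [[7], [9], [0]]])  # (2, 3, 1)
+# >>> corr_idx = torch.tensor([[3], [9]])                          # (2, 1)
+# >>> mask = (cand_idx != 0).float()
+# >>> emb = torch.randn(2, 3, 4)                 # (batch, n_candidates, embed_dim)
+# >>> ids, embs, labels = _construct_labels(emb, cand_idx, corr_idx, mask)
+# >>> ids.squeeze(-1)   # unique non-padding candidates across the batch
+# tensor([3, 7, 9])
+# >>> labels            # one-hot rows: sample 0 -> node 3, sample 1 -> node 9
+# tensor([[1, 0, 0],
+#         [0, 0, 1]])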
+
+
+### https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch/blob/master/criteria/multisimilarity.py
+class MultisimilarityCriterion(torch.nn.Module):
+    def __init__(self, pos_weight, neg_weight, margin, thresh,
+                 embed_dim, only_hard_distractors=True):
+        super().__init__()
+        self.pos_weight = pos_weight
+        self.neg_weight = neg_weight
+        self.margin = margin
+        self.thresh = thresh
+        self.only_hard_distractors = only_hard_distractors
+
+    def forward(self, sims, mask, one_hot_labels, **kwargs):
+        loss = []
+        pos_terms, neg_terms = [], []
+        for i in range(sims.shape[0]):
+
+            pos_idxs = one_hot_labels[i, :] == 1
+            if self.only_hard_distractors:
+                # restrict negatives to the masked (hard distractor) candidates
+                curr_mask = mask[i, :]
+                neg_idxs = ((one_hot_labels[i, :] == 0) * curr_mask)
+            else:
+                neg_idxs = (one_hot_labels[i, :] == 0)
+
+            if not torch.sum(pos_idxs) or not torch.sum(neg_idxs):
+                print('No positive or negative examples available')
+                continue
+
+            anchor_pos_sim = sims[i][pos_idxs]
+            anchor_neg_sim = sims[i][neg_idxs]
+
+            # pair mining: keep negatives within `margin` of the easiest positive
+            # and positives within `margin` of the hardest negative
+            neg_idxs = (anchor_neg_sim + self.margin) > torch.min(anchor_pos_sim)
+            pos_idxs = (anchor_pos_sim - self.margin) < torch.max(anchor_neg_sim)
+            if not torch.sum(neg_idxs):
+                print('No negative examples available - check 2')
+            elif not torch.sum(pos_idxs):
+                print('No positive examples available - check 2')
+            else:
+                anchor_neg_sim = anchor_neg_sim[neg_idxs]
+                anchor_pos_sim = anchor_pos_sim[pos_idxs]
+
+                pos_term = 1. / self.pos_weight * torch.log(1 + torch.sum(torch.exp(-self.pos_weight * (anchor_pos_sim - self.thresh))))
+                neg_term = 1. / self.neg_weight * torch.log(1 + torch.sum(torch.exp(self.neg_weight * (anchor_neg_sim - self.thresh))))
+
+                loss.append(pos_term + neg_term)
+                pos_terms.append(pos_term)
+                neg_terms.append(neg_term)
+
+        if loss == []:
+            # no valid pairs in the batch: return a zero loss that still
+            # participates in the autograd graph
+            loss = torch.Tensor([0]).to(sims.device)
+            pos_terms = torch.Tensor([0]).to(sims.device)
+            neg_terms = torch.Tensor([0]).to(sims.device)
+            loss.requires_grad = True
+        else:
+            loss = torch.mean(torch.stack(loss))
+            pos_terms = torch.mean(torch.stack(pos_terms))
+            neg_terms = torch.mean(torch.stack(neg_terms))
+
+        return loss
+
+
+def construct_batch_labels(candidate_embeddings, candidate_node_idx, correct_node_idx, mask):
+    '''
+    Format the batch for input into the metric-learning loss function.
+    '''
+    batch, n_candidates, embed_dim = candidate_embeddings.shape
+
+    # reshape the mask so it broadcasts against the flattened embeddings
+    mask_reshaped = mask.reshape(batch * n_candidates, -1)
+    expanded_mask = mask_reshaped.expand(-1, embed_dim)
+
+    # flatten the candidate node idx and candidate embeddings
+    candidate_node_idx_flattened = candidate_node_idx.view(batch * n_candidates, -1)
+    candidate_embeddings_flattened = candidate_embeddings.view(batch * n_candidates, -1)
+    candidate_embeddings_flattened = candidate_embeddings_flattened * expanded_mask
+
+    # NOTE: assumes the candidate node idx are already unique; only padding is dropped
+    candidate_node_idx_flattened_unique = candidate_node_idx_flattened[candidate_node_idx_flattened.squeeze() != 0]
+    candidate_embeddings_flattened_unique = candidate_embeddings_flattened[candidate_node_idx_flattened.squeeze() != 0, :]
+
+    # create a one-hot encoding of the correct gene/disease within the candidates in the batch
+    label_idx = torch.where(candidate_node_idx_flattened_unique.unsqueeze(1) == correct_node_idx.unsqueeze(0), 1, 0)
+    label_idx = label_idx.sum(dim=-1).T
+
+    return candidate_node_idx_flattened_unique, candidate_embeddings_flattened_unique, label_idx
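+
+# Hedged usage sketch for the multi-similarity loss above (all tensors and
+# hyperparameter values here are hypothetical, for illustration only):
+#
+# criterion = MultisimilarityCriterion(pos_weight=2.0, neg_weight=50.0,
+#                                      margin=0.1, thresh=0.5, embed_dim=64,
+#                                      only_hard_distractors=False)
+# sims = query_emb @ cand_emb.T          # (batch, n_candidates) similarities
+# loss = criterion(sims, mask, one_hot_labels)
+# loss.backward()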
+
+
+class NCALoss(BaseMetricLossFunction):
+    def __init__(self, softmax_scale=1, only_hard_distractors=False, **kwargs):
+        super().__init__(**kwargs)
+        self.softmax_scale = softmax_scale
+        self.only_hard_distractors = only_hard_distractors
+        self.add_to_recordable_attributes(
+            list_of_names=["softmax_scale"], is_stat=False
+        )
+
+    def forward(self, phenotype_embedding, disease_embedding, batch_disease_nid, batch_cand_disease_nid=None, disease_mask=None, one_hot_labels=None, indices_tuple=None, use_candidate_list=False):
+        """
+        Args:
+            phenotype_embedding: tensor of size (batch_size, embedding_size)
+            disease_embedding: candidate disease embeddings, or None for
+                phenotype-only batches
+            batch_disease_nid: node idx of each sample's correct disease
+            batch_cand_disease_nid: node idx of the candidate diseases
+            disease_mask: mask over the candidate diseases
+            indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
+                or size 4 for pairs (anchor1, positives, anchor2, negatives);
+                can also be left as None
+        Returns: the reduced loss, the disease softmax, the one-hot labels,
+            and the candidate disease idx/embeddings
+        """
+        self.reset_stats()
+        loss_dict, disease_softmax, one_hot_labels, candidate_disease_idx, candidate_disease_embeddings = self.compute_loss(phenotype_embedding, disease_embedding, batch_disease_nid, batch_cand_disease_nid, disease_mask, one_hot_labels, indices_tuple, use_candidate_list)
+        if loss_dict is None:
+            reduction = None
+        else:
+            # only regularize when a loss was actually computed
+            self.add_embedding_regularization_to_loss_dict(loss_dict, phenotype_embedding)
+            reduction = self.reducer(loss_dict, None, None)
+        return reduction, disease_softmax, one_hot_labels, candidate_disease_idx, candidate_disease_embeddings
+
+    # https://www.cs.toronto.edu/~hinton/absps/nca.pdf
+    def compute_loss(self, phenotype_embedding, disease_embedding, batch_corr_disease_nid, batch_cand_disease_nid, disease_mask, labels, indices_tuple, use_candidate_list):
+
+        if len(phenotype_embedding) <= 1:
+            # too few samples to form pairs; return as many values as the caller unpacks
+            return self.zero_losses(), None, None, None, None
+
+        if disease_embedding is None:  # phenotype-phenotype
+            loss_dict, disease_softmax, labels = self.nca_computation(
+                phenotype_embedding, phenotype_embedding, labels, indices_tuple, use_one_hot_labels=False
+            )
+            candidate_disease_idx = None
+            candidate_disease_embeddings = None
+
+        else:
+            # disease-phenotype
+            if self.only_hard_distractors or use_candidate_list:
+                candidate_disease_idx = None  # the caller already has the candidate list
+                candidate_disease_embeddings = disease_embedding
+                phenotype_embedding = phenotype_embedding.unsqueeze(1)
+            else:
+                candidate_disease_idx, candidate_disease_embeddings, labels = construct_batch_labels(disease_embedding, batch_cand_disease_nid, batch_corr_disease_nid, disease_mask)
+
+            loss_dict, disease_softmax, labels = self.nca_computation(
+                phenotype_embedding, candidate_disease_embeddings, labels, indices_tuple, use_one_hot_labels=True
+            )
+
+        return loss_dict, disease_softmax, labels, candidate_disease_idx, candidate_disease_embeddings
+
+    def nca_computation(
+        self, query, reference, labels, indices_tuple, use_one_hot_labels
+    ):
+        dtype = query.dtype
+        mat = self.distance(query, reference)
+        if not self.distance.is_inverted:
+            # negate distances so that larger values mean "closer"
+            mat = -mat
+        mat = mat.squeeze(1)
+
+        if query is reference:
+            # a sample must not select itself as its own neighbor
+            mat.fill_diagonal_(c_f.neg_inf(dtype))
+        softmax = torch.nn.functional.softmax(self.softmax_scale * mat, dim=1)
+
+        if labels.nelement() == 0:
+            loss_dict = None
+        else:
+            if not use_one_hot_labels:
+                # pairwise label-equality matrix for phenotype-phenotype batches
+                labels = c_f.to_dtype(
+                    labels.unsqueeze(1) == labels.unsqueeze(0), dtype=dtype
+                )
+            labels = labels.squeeze(-1)
+            exp = torch.sum(softmax * labels, dim=1)
+            non_zero = exp != 0
+            loss = -torch.log(exp[non_zero])
+            indices = c_f.torch_arange_from_size(query)[non_zero]
+            loss_dict = {
+                "loss": {
+                    "losses": loss,
+                    "indices": indices,
+                    "reduction_type": "element",
+                }
+            }
+        return loss_dict, softmax, labels
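+
+    # For reference (Goldberger et al., "Neighbourhood Components Analysis"):
+    # with distances d_ij between query i and reference j, NCA defines
+    #     p_ij = exp(-d_ij) / sum_k exp(-d_ik),   with p_ii := 0,
+    # and minimizes -log(sum over true matches j of p_ij). `nca_computation`
+    # realizes this with a softmax over scaled negative distances (the -inf
+    # diagonal zeroes out p_ii) and the one-hot or label-equality matrix as
+    # the inner sum.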
+
+    def get_default_distance(self):
+        # squared Euclidean distance, as in the original NCA formulation
+        return LpDistance(power=2)
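+
+# Hedged end-to-end sketch (hypothetical shapes and variable names; nothing in
+# this patch calls it): scoring phenotype embeddings against candidate diseases.
+#
+# loss_fn = NCALoss(softmax_scale=1)
+# loss, softmax, labels, cand_idx, cand_emb = loss_fn(
+#     phenotype_embedding=phen_emb,        # (batch, embed_dim)
+#     disease_embedding=dis_emb,           # (batch, n_candidates, embed_dim)
+#     batch_disease_nid=corr_nid,          # (batch, 1) correct disease node idx
+#     batch_cand_disease_nid=cand_nid,     # (batch, n_candidates, 1)
+#     disease_mask=mask,                   # (batch, n_candidates, 1)
+# )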