--- a +++ b/unimol/losses/cross_entropy.py @@ -0,0 +1,946 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +import torch.nn.functional as F +import pandas as pd +from unicore import metrics +from unicore.losses import UnicoreLoss, register_loss +from unicore.losses.cross_entropy import CrossEntropyLoss +from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score +import numpy as np +import warnings +from sklearn.metrics import top_k_accuracy_score +from rdkit.ML.Scoring.Scoring import CalcBEDROC +import scipy.stats as stats + + +def calculate_bedroc(y_true, y_score, alpha): + """ + Calculate BEDROC score. + + Parameters: + - y_true: true binary labels (0 or 1) + - y_score: predicted scores or probabilities + - alpha: parameter controlling the degree of early retrieval emphasis + + Returns: + - BEDROC score + """ + + # concate res_single and labels + scores = np.expand_dims(y_score, axis=1) + y_true = np.expand_dims(y_true, axis=1) + #print(scores.shape, y_true.shape) + scores = np.concatenate((scores, y_true), axis=1) + # inverse sort scores based on first column + scores = scores[scores[:,0].argsort()[::-1]] + bedroc = CalcBEDROC(scores, 1, 80.5) + return bedroc + +@register_loss("decoder_loss") +class DecoderLoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + features_only=True, + classification_head_name=None, + fix_encoder=fix_encoder + ) + loss = self.compute_loss(model, net_output, sample, reduce=reduce) + targets = sample["net_input"]["selfie_tokens"] + sample_size = targets.size(0) + + lprobs = net_output[:,:,:targets.shape[-1]] + if not self.training: + logging_output = { + "loss": loss.data, + "prob": lprobs.data, + "target": targets.data, + "smi_name": sample["smi_name"], + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = net_output + targets = sample["net_input"]["selfie_tokens"] + lprobs = lprobs[:,:,:targets.shape[-1]] + # print("111", lprobs.shape, targets.shape, sample["target"]["finetune_target"].shape) + nll_loss = F.nll_loss( + lprobs, + targets, + reduction="sum" if reduce else "none", + ) / lprobs.shape[-1] + + loss = nll_loss + #print(loss.data, nll_loss.data, kld_loss.data) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + # if split == "valid": + # print("hi1") + loss = sum(log.get("loss", 0).float() for log in logging_outputs) + # if split == "valid": + # print("hi2") + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # if split == "valid": + # print("hi3") + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss_all", loss / sample_size / math.log(2), sample_size, round=3 + ) + + if "valid" in split or "test" in split: + + prob_list = [] + pred_list = [] + target_list = [] + for log in logging_outputs: + prob = log.get("prob") + prob = torch.transpose(prob, 1, 2) + prob = prob.reshape((-1, prob.shape[-1])) + prob_list.append(prob) + pred = log.get("prob").argmax(dim=1) + pred = pred.flatten() + pred_list.append(pred) + target = log.get("target") + target = target.flatten() + target_list.append(target) + + preds = torch.cat(pred_list, dim=0) + targets = torch.cat(target_list, dim=0) + #print(preds.shape, targets.shape) + acc = (preds == targets).float().mean(dim=-1) + #print(acc.shape) + metrics.log_scalar( + f"{split}_acc", acc , sample_size, round=3 + ) + + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + + +@register_loss("decoder_vae_loss") +class DecoderVAELoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, candidate_reps, candidate_embs, candidate_smiles, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + candidate_reps=candidate_reps, + candidate_embs=candidate_embs, + candidate_smiles=candidate_smiles, + features_only=True, + classification_head_name=None, + fix_encoder=fix_encoder + ) + loss, nll_loss, kld_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + targets = sample["net_input"]["selfie_tokens"] + sample_size = targets.size(0) + + lprobs = net_output[0][:,:,:targets.shape[-1]] + if not self.training: + logging_output = { + "loss": loss.data, + "kld_loss": kld_loss.data, + "nll_loss": nll_loss.data, + "prob": lprobs.data, + "target": targets.data, + "smi_name": sample["smi_name"], + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "kld_loss": kld_loss.data, + "nll_loss": nll_loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + out, z, mu, log_var = net_output + lprobs = out + targets = sample["net_input"]["selfie_tokens"] + lprobs = lprobs[:,:,:targets.shape[-1]] + # print("111", lprobs.shape, targets.shape, sample["target"]["finetune_target"].shape) + nll_loss = F.nll_loss( + lprobs, + targets, + reduction="sum" if reduce else "none", + ) / lprobs.shape[-1] + kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / lprobs.shape[1] + p=0.2 + loss = p * kld_loss + (1-p) * nll_loss + #print(loss.data, nll_loss.data, kld_loss.data) + return loss, nll_loss, kld_loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + # if split == "valid": + # print("hi1") + loss_all = sum(log.get("loss", 0).float() for log in logging_outputs) + loss_kld = sum(log.get("kld_loss", 0).float() for log in logging_outputs) + loss_nll = sum(log.get("nll_loss", 0).float() for log in logging_outputs) + # if split == "valid": + # print("hi2") + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # if split == "valid": + # print("hi3") + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss_all", loss_all / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "loss_kld", loss_kld / sample_size, sample_size, round=3 + ) + metrics.log_scalar( + "loss_nll", loss_nll / sample_size, sample_size, round=3 + ) + if "valid" in split or "test" in split: + + prob_list = [] + pred_list = [] + target_list = [] + for log in logging_outputs: + prob = log.get("prob") + prob = torch.transpose(prob, 1, 2) + prob = prob.reshape((-1, prob.shape[-1])) + prob_list.append(prob) + pred = log.get("prob").argmax(dim=1) + pred = pred.flatten() + pred_list.append(pred) + target = log.get("target") + target = target.flatten() + target_list.append(target) + + probs = torch.cat(prob_list, dim=0) + preds = torch.cat(pred_list, dim=0) + targets = torch.cat(target_list, dim=0) + #print(preds.shape, targets.shape) + acc = (preds == targets).float().mean(dim=-1) + #print(acc.shape) + metrics.log_scalar( + f"{split}_acc", acc , sample_size, round=3 + ) + ''' + # smi_list = [ + # item for log in logging_outputs for item in log.get("smi_name") + # ] + probs = torch.exp(probs) + #prob_flat = prob_flat.reshape((-1, prob_flat.shape[-1])) + print(probs.shape) + + #targets = targets.squeeze(dim=-1) + auc = roc_auc_score(targets.cpu(), probs.cpu(), multi_class="ovo", labels=torch.arange(probs.shape[-1])) + #df = df.groupby("smi").mean() + #agg_auc = roc_auc_score(df["targets"], df["probs"]) + agg_auc = auc + + metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3) + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + ''' + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + + +@register_loss("finetune_cross_entropy") +class FinetuneCrossEntropyLoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.args.classification_head_name, + fix_encoder=fix_encoder + ) + logit_output = net_output[0] + loss = self.compute_loss(model, logit_output, sample, reduce=reduce) + sample_size = sample["target"]["finetune_target"].size(0) + if not self.training: + probs = F.softmax(logit_output.float(), dim=-1).view( + -1, logit_output.size(-1) + ) + logging_output = { + "loss": loss.data, + "prob": probs.data, + "target": sample["target"]["finetune_target"].view(-1).data, + "smi_name": sample["smi_name"], + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = F.log_softmax(net_output.float(), dim=-1) + lprobs = lprobs.view(-1, lprobs.size(-1)) + targets = sample["target"]["finetune_target"].view(-1) + # print("111", lprobs.shape, targets.shape, sample["target"]["finetune_target"].shape) + loss = F.nll_loss( + lprobs, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + acc_sum = sum( + sum(log.get("prob").argmax(dim=-1) == log.get("target")) + for log in logging_outputs + ) + probs = torch.cat([log.get("prob") for log in logging_outputs], dim=0) + metrics.log_scalar( + f"{split}_acc", acc_sum / sample_size, sample_size, round=3 + ) + if probs.size(-1) == 2: + # binary classification task, add auc score + targets = torch.cat( + [log.get("target", 0) for log in logging_outputs], dim=0 + ) + smi_list = [ + item for log in logging_outputs for item in log.get("smi_name") + ] + df = pd.DataFrame( + { + "probs": probs[:, 1].cpu(), + "targets": targets.cpu(), + "smi": smi_list, + } + ) + auc = roc_auc_score(df["targets"], df["probs"]) + df = df.groupby("smi").mean() + agg_auc = roc_auc_score(df["targets"], df["probs"]) + metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3) + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + +@register_loss("ce") +class CEntropyLoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + smi_list = sample["smi_name"], + pocket_list = sample["pocket_name"], + features_only=True, + fix_encoder=fix_encoder + ) + logit_output = net_output + loss = self.compute_loss(model, logit_output, sample, reduce=reduce) + #print(sample["target"]["finetune_target"]) + sample_size = sample["target"]["finetune_target"].size(0) + if not self.training: + probs = torch.sigmoid(logit_output.float()) + logging_output = { + "loss": loss.data, + "prob": probs.data, + "target": sample["target"]["finetune_target"].view(-1).data, + "smi_name": sample["smi_name"], + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + + targets = sample["target"]["finetune_target"].view(-1) + # print("111", lprobs.shape, targets.shape, sample["target"]["finetune_target"].shape) + #print(net_output.shape, targets.shape) + loss = F.binary_cross_entropy_with_logits( + net_output.float(), + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + + # get acc + acc_sum = 0 + for log in logging_outputs: + probs = torch.sigmoid(log.get("prob")) + targets = log.get("target") + probs = probs > 0.5 + # convert to int + probs = probs.long() + + + logs = [log.get("prob") for log in logging_outputs] + targets = torch.cat( + [log.get("target", 0) for log in logging_outputs], dim=0 + ) + probs = torch.cat([log.get("prob") for log in logging_outputs], dim=0) + #probs = torch.sigmoid(probs) + print(probs.shape, targets.shape) + print(probs[:10], targets[:10]) + preds = probs > 0.5 + # convert to int + preds = preds.long() + acc_sum = (preds == targets).sum() + metrics.log_scalar( + f"{split}_acc", acc_sum / sample_size, sample_size, round=3 + ) + # binary classification task, add auc score + + + smi_list = [ + item for log in logging_outputs for item in log.get("smi_name") + ] + df = pd.DataFrame( + { + "probs": probs.cpu(), + "targets": targets.cpu(), + "smi": smi_list, + } + ) + # get values of df["targets"] + # + + auc = roc_auc_score(df["targets"].values, df["probs"].values) + df = df.groupby("smi").mean() + agg_auc = roc_auc_score(df["targets"], df["probs"]) + #print(df["targets"]) + metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3) + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + + +@register_loss("in_batch_softmax") +class IBSLoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + smi_list = sample["smi_name"], + pocket_list = sample["pocket_name"], + features_only=True, + fix_encoder=fix_encoder, + is_train = self.training + ) + + logit_output = net_output[0] + loss = self.compute_loss(model, logit_output, sample, reduce=reduce) + sample_size = logit_output.size(0) + targets = torch.arange(sample_size, dtype=torch.long).cuda() + affinities = sample["target"]["finetune_target"].view(-1) + if not self.training: + logit_output = logit_output[:,:sample_size] + probs = F.softmax(logit_output.float(), dim=-1).view( + -1, logit_output.size(-1) + ) + logging_output = { + "loss": loss.data, + "prob": probs.data, + "target": targets, + "smi_name": sample["smi_name"], + "sample_size": sample_size, + "bsz": targets.size(0), + "scale": net_output[1].data, + "affinity": affinities, + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": targets.size(0), + "scale": net_output[1].data + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs_pocket = F.log_softmax(net_output.float(), dim=-1) + lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1)) + sample_size = lprobs_pocket.size(0) + targets= torch.arange(sample_size, dtype=torch.long).view(-1).cuda() + + # pocket retrieve mol + loss_pocket = F.nll_loss( + lprobs_pocket, + targets, + reduction="sum" if reduce else "none", + ) + + lprobs_mol = F.log_softmax(torch.transpose(net_output.float(), 0, 1), dim=-1) + lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1)) + lprobs_mol = lprobs_mol[:sample_size] + + # mol retrieve pocket + loss_mol = F.nll_loss( + lprobs_mol, + targets, + reduction="sum" if reduce else "none", + ) + + loss = 0.5 * loss_pocket + 0.5 * loss_mol + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + metrics.log_scalar("scale", logging_outputs[0].get("scale"), round=3) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + acc_sum = sum( + sum(log.get("prob").argmax(dim=-1) == log.get("target")) + for log in logging_outputs + ) + + prob_list = [] + if len(logging_outputs) == 1: + prob_list.append(logging_outputs[0].get("prob")) + else: + for i in range(len(logging_outputs)-1): + prob_list.append(logging_outputs[i].get("prob")) + probs = torch.cat(prob_list, dim=0) + + metrics.log_scalar( + f"{split}_acc", acc_sum / sample_size, sample_size, round=3 + ) + + metrics.log_scalar( + "valid_neg_loss", -loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + + + targets = torch.cat( + [log.get("target", 0) for log in logging_outputs], dim=0 + ) + print(targets.shape, probs.shape) + + targets = targets[:len(probs)] + bedroc_list = [] + auc_list = [] + for i in range(len(probs)): + prob = probs[i] + target = targets[i] + label = torch.zeros_like(prob) + label[target]=1.0 + cur_auc = roc_auc_score(label.cpu(), prob.cpu()) + auc_list.append(cur_auc) + bedroc = calculate_bedroc(label.cpu(), prob.cpu(), 80.5) + bedroc_list.append(bedroc) + bedroc = np.mean(bedroc_list) + auc = np.mean(auc_list) + + top_k_acc = top_k_accuracy_score(targets.cpu(), probs.cpu(), k=3, normalize=True) + metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3) + metrics.log_scalar(f"{split}_bedroc", bedroc, sample_size, round=3) + metrics.log_scalar(f"{split}_top3_acc", top_k_acc, sample_size, round=3) + + + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + + + + + + +@register_loss("multi_task_BCE") +class MultiTaskBCELoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + masked_tokens=None, + features_only=True, + classification_head_name=self.args.classification_head_name, + fix_encoder=fix_encoder + ) + logit_output = net_output[0] + is_valid = sample["target"]["finetune_target"] > -0.5 + loss = self.compute_loss( + model, logit_output, sample, reduce=reduce, is_valid=is_valid + ) + sample_size = sample["target"]["finetune_target"].size(0) + if not self.training: + probs = torch.sigmoid(logit_output.float()).view(-1, logit_output.size(-1)) + logging_output = { + "loss": loss.data, + "prob": probs.data, + "target": sample["target"]["finetune_target"].view(-1).data, + "num_task": self.args.num_classes, + "sample_size": sample_size, + "conf_size": self.args.conf_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None): + pred = net_output[is_valid].float() + targets = sample["target"]["finetune_target"][is_valid].float() + loss = F.binary_cross_entropy_with_logits( + pred, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + agg_auc_list = [] + num_task = logging_outputs[0].get("num_task", 0) + conf_size = logging_outputs[0].get("conf_size", 0) + y_true = ( + torch.cat([log.get("target", 0) for log in logging_outputs], dim=0) + .view(-1, conf_size, num_task) + .cpu() + .numpy() + .mean(axis=1) + ) + y_pred = ( + torch.cat([log.get("prob") for log in logging_outputs], dim=0) + .view(-1, conf_size, num_task) + .cpu() + .numpy() + .mean(axis=1) + ) + # [test_size, num_classes] = [test_size * conf_size, num_classes].mean(axis=1) + for i in range(y_true.shape[1]): + # AUC is only defined when there is at least one positive data. + if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: + # ignore nan values + is_labeled = y_true[:, i] > -0.5 + agg_auc_list.append( + roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i]) + ) + if len(agg_auc_list) < y_true.shape[1]: + warnings.warn("Some target is missing!") + if len(agg_auc_list) == 0: + raise RuntimeError( + "No positively labeled data available. Cannot compute Average Precision." + ) + agg_auc = sum(agg_auc_list) / len(agg_auc_list) + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + +@register_loss("BCE") +class BCELoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True, fix_encoder=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + smi_list = sample["smi_name"], + pocket_list = sample["pocket_name"], + features_only=True, + fix_encoder=fix_encoder + ) + logit_output = net_output + loss = self.compute_loss( + model, logit_output, sample, reduce=reduce + ) + sample_size = sample["target"]["finetune_target"].size(0) + + if not self.training: + probs = torch.sigmoid(logit_output.float()) + #print(probs.size()) + logging_output = { + "loss": loss.data, + "prob": probs.data, + "target": sample["target"]["finetune_target"].view(-1).data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None): + pred = net_output.float() + targets = sample["target"]["finetune_target"].float() + loss = F.binary_cross_entropy_with_logits( + pred, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + y_true_list = [] + y_pred_list = [] + y_true_list = [log.get("target", 0) for log in logging_outputs] + y_pred_list = [log.get("prob") for log in logging_outputs] + y_true = ( + torch.cat(y_true_list, dim=0) + .cpu() + .numpy() + ) + y_pred = ( + torch.cat(y_pred_list, dim=0) + .cpu() + .numpy() + ) + # [test_size, num_classes] = [test_size * conf_size, num_classes].mean(axis=1) + + auc = roc_auc_score(y_true , y_pred) + + + agg_auc = auc + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return is_train + +@register_loss("finetune_cross_entropy_pocket") +class FinetuneCrossEntropyPocketLoss(FinetuneCrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.args.classification_head_name, + ) + logit_output = net_output[0] + loss = self.compute_loss(model, logit_output, sample, reduce=reduce) + sample_size = sample["target"]["finetune_target"].size(0) + if not self.training: + probs = F.softmax(logit_output.float(), dim=-1).view( + -1, logit_output.size(-1) + ) + logging_output = { + "loss": loss.data, + "prob": probs.data, + "target": sample["target"]["finetune_target"].view(-1).data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + else: + logging_output = { + "loss": loss.data, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].size(0), + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + acc_sum = sum( + sum(log.get("prob").argmax(dim=-1) == log.get("target")) + for log in logging_outputs + ) + metrics.log_scalar( + f"{split}_acc", acc_sum / sample_size, sample_size, round=3 + ) + preds = ( + torch.cat( + [log.get("prob").argmax(dim=-1) for log in logging_outputs], dim=0 + ) + .cpu() + .numpy() + ) + targets = ( + torch.cat([log.get("target", 0) for log in logging_outputs], dim=0) + .cpu() + .numpy() + ) + metrics.log_scalar(f"{split}_pre", precision_score(targets, preds), round=3) + metrics.log_scalar(f"{split}_rec", recall_score(targets, preds), round=3) + metrics.log_scalar( + f"{split}_f1", f1_score(targets, preds), sample_size, round=3 + )