[d3ab9c]: / bert_mixup / early_mixup / utils.py

Download this file

34 lines (25 with data), 1.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from sklearn.metrics import roc_auc_score
from torch import sigmoid
from torch.nn.functional import softmax
def flat_auroc_score(preds, labels):
    """Compute the ROC AUC of the predicted positive-class scores.

    Args:
        preds: Tensor of per-class scores. A softmax is applied over dim 1
            and column 1 of the result is used as the score handed to
            ``roc_auc_score``.
        labels: Ground-truth targets forwarded to ``roc_auc_score``. May be
            a ``torch.Tensor`` (on any device) or any array-like sklearn
            accepts.

    Returns:
        Whatever ``sklearn.metrics.roc_auc_score`` returns for the given
        labels and scores.
    """
    positive_scores = softmax(preds, dim=1)[:, 1]

    # The scores are explicitly detached and moved to host memory before
    # being handed to sklearn; do the same for tensor labels, since a tensor
    # living on an accelerator (or tracking gradients) cannot be converted
    # to a NumPy array implicitly.
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    # NOTE(review): labels are used as-is (no argmax), so they are presumably
    # class indices rather than one-hot vectors -- confirm with the caller.
    return roc_auc_score(labels, positive_scores.detach().cpu().numpy())
def mixup_augment(embedding1, embedding2, label1, label2, lam):
    """Linearly interpolate two embeddings and their labels.

    Both pairs are mixed with the same weights: ``lam`` for the first item
    and ``1.0 - lam`` for the second.

    Returns:
        A ``(mixed_embedding, mixed_label)`` tuple.
    """
    complement = 1.0 - lam

    def _blend(first, second):
        # Weighted sum shared by the embedding and the label mix.
        return lam * first + complement * second

    return _blend(embedding1, embedding2), _blend(label1, label2)
def get_perm(x, args):
    """Return a random permutation of the indices along the first axis of x.

    The permutation has ``x.size()[0]`` entries. It is moved to the GPU with
    ``.cuda()`` only when ``args.cuda`` is truthy and CUDA is available;
    otherwise it is returned as created by ``torch.randperm``.
    """
    num_rows = x.size()[0]
    on_gpu = args.cuda and torch.cuda.is_available()
    shuffled = torch.randperm(num_rows)
    return shuffled.cuda() if on_gpu else shuffled
def mixup_criterion_cross_entropy(criterion, pred, y_a, y_b, lam):
    """Combine the criterion's loss against both mixup targets.

    ``criterion`` is evaluated on ``pred`` against ``y_a`` and then against
    ``y_b``; the two results are weighted by ``lam`` and ``1 - lam``
    respectively and summed.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second