[fceaa9]: / losses.py

Download this file

21 lines (13 with data), 580 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import torch
from torch import nn
class MaskedMultiTaskCrossEntropy(nn.Module):
    """Mean binary cross-entropy over multiple tasks, ignoring missing labels.

    Label encoding used by ``forward``:

    * ``1``  -> positive (active) label,
    * ``0``  -> label is missing; the entry contributes nothing to the loss,
    * any other value (``-1`` by convention) -> negative (inactive) label.
    """

    def forward(self, input, target):
        """Compute the masked loss.

        Args:
            input: Raw, pre-sigmoid scores (logits).
            target: Labels encoded as described in the class docstring; must
                be broadcastable against ``input``.

        Returns:
            A scalar tensor: the sum of the per-entry binary cross-entropy
            terms over the non-missing entries divided by the number of
            non-missing entries. If every entry is missing the result is 0
            instead of NaN.
        """
        log_sigmoid = nn.functional.logsigmoid

        target_active = (target == 1).float()  # from -1/1 to 0/1

        # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) are
        # evaluated directly in log-space. Taking torch.log of a saturated
        # sigmoid yields -inf, and 0 * -inf then poisons the sum with NaN.
        log_prob_active = log_sigmoid(input)
        log_prob_inactive = log_sigmoid(-input)
        loss_terms = -(
            target_active * log_prob_active
            + (1 - target_active) * log_prob_inactive
        )

        missing_values_mask = (target != 0).float()
        # The mask sum is an entry count, so clamping to 1 only changes the
        # all-missing case (0 / 0 -> 0 / 1) and leaves every other case exact.
        labelled_count = missing_values_mask.sum().clamp(min=1.0)
        return (loss_terms * missing_values_mask).sum() / labelled_count
# Registry mapping a loss name to a ready-to-call loss module instance.
LOSS_FUNCTIONS = {
    "MaskedMultiTaskCrossEntropy": MaskedMultiTaskCrossEntropy(),
    "MSE": nn.MSELoss(),
}