[4fa73e]: / pytorch / graphs / losses / cross_entropy.py

Download this file

21 lines (16 with data), 683 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
class CrossEntropyLoss(nn.Module):
    """Cross-entropy criterion wrapping ``nn.CrossEntropyLoss``, optionally config-driven.

    Args:
        config: Optional configuration object. When ``None``, a plain
            ``nn.CrossEntropyLoss()`` with PyTorch defaults is used.
            Otherwise the object must provide:

            * ``class_weights`` -- anything ``np.load`` accepts (e.g. a path
              to a ``.npy`` file) holding per-class weights. PyTorch requires
              the resulting weight to be a 1-D tensor of size C (the number
              of classes).
            * ``ignore_index`` -- target value whose entries are ignored and
              do not contribute to the loss.

            In both cases the per-element losses are averaged
            (``reduction="mean"``).
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            self.loss = nn.CrossEntropyLoss()
        else:
            class_weights = np.load(config.class_weights)
            # Cast to float32 so the weight dtype matches typical float32 logits.
            # nn.CrossEntropyLoss registers the weight as a buffer, so it moves
            # with this module on .to(device) / .cuda().
            weight = torch.from_numpy(class_weights.astype(np.float32))
            # reduction="mean" is the non-deprecated equivalent of the legacy
            # size_average=True, reduce=True pair, which triggers a warning.
            self.loss = nn.CrossEntropyLoss(
                weight=weight,
                ignore_index=config.ignore_index,
                reduction="mean",
            )

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy loss of ``inputs`` against ``targets``.

        Both arguments are passed straight to ``nn.CrossEntropyLoss``; see its
        documentation for the accepted shapes (presumably (N, C[, ...]) logits
        and (N[, ...]) class-index targets here -- confirm against callers).
        """
        return self.loss(inputs, targets)