[4fa73e]: / pytorch / graphs / losses / example.py

Download this file

17 lines (12 with data), 401 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
"""
An example of a loss class definition to be used by the agent.
"""
import torch.nn as nn
class CrossEntropyLoss3d(nn.Module):
    """Cross-entropy loss module wrapping ``nn.CrossEntropyLoss``.

    Args:
        weight: optional per-class rescaling weight, forwarded unchanged to
            ``nn.CrossEntropyLoss`` (``None`` means all classes weigh equally).
        size_average: if True (default) the per-element losses are averaged
            (``reduction="mean"``); if False they are summed
            (``reduction="sum"``).
    """

    def __init__(self, weight=None, size_average=True):
        # The original called super(CrossEntropyLoss2d, self), naming a class
        # that does not exist here, so every instantiation raised NameError.
        super().__init__()
        # ``size_average`` is deprecated in PyTorch's loss modules; translate it
        # to the equivalent ``reduction`` value instead of dropping it.
        reduction = "mean" if size_average else "sum"
        self.loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, logits, labels):
        """Return the cross-entropy of ``logits`` against integer ``labels``.

        Inputs are passed straight to ``nn.CrossEntropyLoss``; presumably
        logits are (N, C, ...) and labels (N, ...) — confirm against callers.
        """
        return self.loss(logits, labels)
# Dice loss