[f45789]: / src / criterion.py

Download this file

11 lines (9 with data), 273 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
import torch.nn as nn
def get_criterion(conf):
    """Build the loss function named in the experiment configuration.

    Args:
        conf: Configuration mapping. The loss name is read from
            ``conf['criterion']['name']``; a missing key raises ``KeyError``.

    Returns:
        A freshly constructed ``torch.nn`` loss module.

    Raises:
        SystemExit: If the configured name is not supported. When uncaught,
            the message is printed to stderr and the process exits with
            status 1.
    """
    name = conf['criterion']['name']
    # Supported loss names mapped to their constructors; register new losses here.
    supported = {
        'CrossEntropyLoss': nn.CrossEntropyLoss,
    }
    factory = supported.get(name)
    if factory is None:
        # Raise SystemExit with the message instead of calling the site-module
        # `exit()`: a bare `exit()` reports success (status 0) and is absent
        # under `python -S`, whereas a string payload exits with status 1 and
        # prints the message to stderr.
        raise SystemExit(f'Criterion {name} not supported.')
    return factory()