Diff of /src/criterion.py [000000] .. [f45789]

Switch to unified view

a b/src/criterion.py
1
import torch.nn as nn
2
3
def get_criterion(conf):
4
    criterion = conf['criterion']['name']
5
    if criterion == 'CrossEntropyLoss':
6
        criterion = nn.CrossEntropyLoss()
7
    else:
8
        print(f'Criterion {criterion} not supported.')
9
        exit()
10
    return criterion