a | b/src/criterion.py | ||
---|---|---|---|
1 | import torch.nn as nn |
||
2 | |||
3 | def get_criterion(conf): |
||
4 | criterion = conf['criterion']['name'] |
||
5 | if criterion == 'CrossEntropyLoss': |
||
6 | criterion = nn.CrossEntropyLoss() |
||
7 | else: |
||
8 | print(f'Criterion {criterion} not supported.') |
||
9 | exit() |
||
10 | return criterion |