# rocaseg/components/__init__.py
1
from torch import nn
2
from torch import optim
3
from rocaseg.components.losses import CrossEntropyLoss
4
from rocaseg.components.metrics import (confusion_matrix, dice_score,
5
                                        dice_score_from_cm)
6
from rocaseg.components.checkpoint import CheckpointHandler
7
8
9
# Registry of training losses, keyed by the name used to select them.
# Values are loss *classes* (not instances); presumably the caller
# instantiates the selected entry — confirm against the training code.
dict_losses = {
    # Binary cross-entropy that applies the sigmoid internally, so it
    # takes raw logits.
    'bce_loss': nn.BCEWithLogitsLoss,
    # Project-specific multi-class cross-entropy (rocaseg.components.losses).
    'multi_ce_loss': CrossEntropyLoss,
}
13
14
15
# Registry of evaluation metrics, keyed by name. Values are callables:
# plain functions from rocaseg.components.metrics, plus one torch module
# instance that is created once at import time.
dict_metrics = {
    'confusion_matrix': confusion_matrix,
    'dice_score': dice_score,
    # NOTE(review): nn.BCELoss expects probabilities in [0, 1], whereas the
    # training loss registered under the same key (nn.BCEWithLogitsLoss)
    # takes raw logits. Confirm the caller applies a sigmoid before
    # computing this metric.
    'bce_loss': nn.BCELoss(),
}
20
21
22
# Registry of optimizers, keyed by the name used to select them.
# Values are torch.optim optimizer *classes*; presumably the caller
# instantiates the selected one with model parameters and hyperparameters.
dict_optimizers = {
    'sgd': optim.SGD,
    'adam': optim.Adam,
}
26
27
28
# Public API of rocaseg.components: the name -> component registries
# defined above, plus re-exported metric helpers and the checkpoint handler.
__all__ = [
    'dict_losses',
    'dict_metrics',
    'dict_optimizers',
    'confusion_matrix',
    'dice_score',
    'dice_score_from_cm',
    'CheckpointHandler',
]