--- a +++ b/rocaseg/components/__init__.py @@ -0,0 +1,36 @@ +from torch import nn +from torch import optim +from rocaseg.components.losses import CrossEntropyLoss +from rocaseg.components.metrics import (confusion_matrix, dice_score, + dice_score_from_cm) +from rocaseg.components.checkpoint import CheckpointHandler + + +dict_losses = { + 'bce_loss': nn.BCEWithLogitsLoss, + 'multi_ce_loss': CrossEntropyLoss, +} + + +dict_metrics = { + 'confusion_matrix': confusion_matrix, + 'dice_score': dice_score, + 'bce_loss': nn.BCELoss(), +} + + +dict_optimizers = { + 'sgd': optim.SGD, + 'adam': optim.Adam, +} + + +__all__ = [ + 'dict_losses', + 'dict_metrics', + 'dict_optimizers', + 'confusion_matrix', + 'dice_score', + 'dice_score_from_cm', + 'CheckpointHandler', +]