Diff of /monai/utils.py [000000] .. [4e96d3]

b/monai/utils.py
import torch
from torch import nn, optim
from torch.nn.modules.loss import _Loss
from torch.optim import lr_scheduler

from monai.data import CacheDataset, DataLoader, PersistentDataset, ThreadDataLoader
from monai.losses import DiceLoss
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.transforms import LoadImage
from monai.utils import LossReduction, set_determinism


class DiceBceMultilabelLoss(_Loss):
    """Weighted sum of Dice loss and BCE-with-logits loss for multilabel segmentation."""

    def __init__(
        self,
        w_dice=0.5,
        w_bce=0.5,
        reduction=LossReduction.MEAN,
    ):
        super().__init__(reduction=LossReduction(reduction).value)
        self.w_dice = w_dice
        self.w_bce = w_bce
        self.dice_loss = DiceLoss(
            sigmoid=True,
            smooth_nr=0.01,
            smooth_dr=0.01,
            include_background=True,
            batch=True,
            squared_pred=True,
        )
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, label):
        # Both terms expect raw logits: DiceLoss applies the sigmoid internally
        # and BCEWithLogitsLoss fuses it with the cross-entropy computation.
        loss = self.dice_loss(pred, label) * self.w_dice + self.bce_loss(pred, label) * self.w_bce
        return loss
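
# Example usage (illustrative sketch, not part of the original module); the
# shapes below are assumptions for a 3-channel 3D multilabel task:
#
#   criterion = DiceBceMultilabelLoss(w_dice=0.5, w_bce=0.5)
#   logits = torch.randn(2, 3, 64, 64, 64)                     # raw network outputs
#   target = torch.randint(0, 2, (2, 3, 64, 64, 64)).float()   # binary multilabel masks
#   loss = criterion(logits, target)                           # scalar tensor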


def get_train_dataloader(train_dataset, cfg):
    # When the data is cached on the GPU, use ThreadDataLoader with num_workers=0
    # so batches are produced in the main process (CUDA tensors cannot be shared
    # with multiprocessing workers).
    if cfg.gpu_cache:
        train_dataloader = ThreadDataLoader(
            train_dataset,
            shuffle=True,
            batch_size=cfg.batch_size,
            num_workers=0,
            drop_last=True,
        )
        return train_dataloader

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        drop_last=True,
    )
    return train_dataloader


def get_val_dataloader(val_dataset, cfg):
    if cfg.val_gpu_cache:
        val_dataloader = ThreadDataLoader(
            val_dataset,
            batch_size=cfg.val_batch_size,
            num_workers=0,
        )
        return val_dataloader

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=cfg.val_batch_size,
        num_workers=cfg.num_workers,
    )
    return val_dataloader


def get_train_dataset(cfg):
    train_ds = CacheDataset(
        data=cfg.data_json["train"],
        transform=cfg.train_transforms,
        cache_rate=cfg.train_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    # train_ds = PersistentDataset(
    #     data=cfg.data_json["train"],
    #     transform=cfg.train_transforms,
    #     cache_dir="cache_data",
    # )
    return train_ds


def get_val_dataset(cfg):
    val_ds = CacheDataset(
        data=cfg.data_json["val"],
        transform=cfg.val_transforms,
        cache_rate=cfg.val_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    return val_ds


def get_val_org_dataset(cfg):
    # Same validation split, but built with the "org" (original-image) transform
    # pipeline from the config.
    val_ds = CacheDataset(
        data=cfg.data_json["val"],
        transform=cfg.org_val_transforms,
        cache_rate=cfg.val_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    return val_ds
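
# Example wiring (illustrative sketch, not part of the original module); every
# attribute on this cfg is an assumption about what the training config provides:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       data_json={"train": [...], "val": [...]},   # lists of MONAI data dicts
#       train_transforms=..., val_transforms=..., org_val_transforms=...,
#       train_cache_rate=1.0, val_cache_rate=1.0, num_workers=4,
#       batch_size=2, val_batch_size=1, gpu_cache=False, val_gpu_cache=False,
#   )
#   train_loader = get_train_dataloader(get_train_dataset(cfg), cfg)
#   val_loader = get_val_dataloader(get_val_dataset(cfg), cfg)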


def get_optimizer(model, cfg):
    params = model.parameters()
    optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    return optimizer


def get_scheduler(cfg, optimizer, total_steps):
    # Both schedules are sized in optimizer steps; total_steps // cfg.batch_size
    # is treated as the number of steps per epoch.
    if cfg.lr_mode == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.epochs * (total_steps // cfg.batch_size),
            eta_min=cfg.min_lr,
        )

    elif cfg.lr_mode == "warmup_restart":
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=cfg.restart_epoch * (total_steps // cfg.batch_size),
            T_mult=1,
            eta_min=cfg.min_lr,
        )

    else:
        raise ValueError(f"Unknown lr_mode: {cfg.lr_mode}")

    return scheduler
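
# Example usage (illustrative sketch, not part of the original module); since
# T_max / T_0 above are counted in optimizer steps, the scheduler is assumed to
# be stepped once per batch rather than once per epoch:
#
#   optimizer = get_optimizer(model, cfg)
#   scheduler = get_scheduler(cfg, optimizer, total_steps=len(train_dataset))
#   for epoch in range(cfg.epochs):
#       for batch in train_loader:
#           ...                 # forward, loss, backward, optimizer.step()
#           scheduler.step()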


def create_checkpoint(model, optimizer, epoch, scheduler=None, scaler=None):
    # Bundle everything needed to resume training; scheduler and AMP GradScaler
    # states are included only when they are provided.
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }

    if scheduler is not None:
        checkpoint["scheduler"] = scheduler.state_dict()

    if scaler is not None:
        checkpoint["scaler"] = scaler.state_dict()
    return checkpoint
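
# Example usage (illustrative sketch, not part of the original module); the file
# name is an arbitrary placeholder:
#
#   checkpoint = create_checkpoint(model, optimizer, epoch, scheduler=scheduler, scaler=scaler)
#   torch.save(checkpoint, "checkpoint_last.pth")
#
#   state = torch.load("checkpoint_last.pth", map_location="cpu")
#   model.load_state_dict(state["model"])
#   optimizer.load_state_dict(state["optimizer"])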