monai/utils.py

import torch
from torch import nn, optim
from torch.nn.modules.loss import _Loss
from torch.optim import lr_scheduler

from monai.data import CacheDataset, DataLoader, PersistentDataset, ThreadDataLoader
from monai.losses import DiceLoss
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.transforms import LoadImage
from monai.utils import LossReduction, set_determinism


class DiceBceMultilabelLoss(_Loss):
    """Weighted sum of DiceLoss and BCEWithLogitsLoss for multilabel segmentation."""

    def __init__(
        self,
        w_dice=0.5,
        w_bce=0.5,
        reduction=LossReduction.MEAN,
    ):
        super().__init__(reduction=LossReduction(reduction).value)
        self.w_dice = w_dice
        self.w_bce = w_bce
        self.dice_loss = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, include_background=True, batch=True, squared_pred=True)
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, label):
        # `pred` is expected to contain raw logits: DiceLoss applies a sigmoid
        # internally and BCEWithLogitsLoss expects logits as well.
        loss = self.dice_loss(pred, label) * self.w_dice + self.bce_loss(pred, label) * self.w_bce
        return loss
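
# Illustrative usage sketch for DiceBceMultilabelLoss: `pred` is expected to be
# raw logits and `label` a float multilabel mask of the same shape (the shapes
# below are hypothetical examples, not taken from a real config).
#
#   loss_fn = DiceBceMultilabelLoss(w_dice=0.5, w_bce=0.5)
#   pred = torch.randn(2, 3, 64, 64, 64)                     # [B, C, H, W, D] logits
#   label = torch.randint(0, 2, (2, 3, 64, 64, 64)).float()  # binary multilabel mask
#   loss = loss_fn(pred, label)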


def get_train_dataloader(train_dataset, cfg):
    """Build the training dataloader; use ThreadDataLoader when the data is cached on GPU."""
    if cfg.gpu_cache:
        train_dataloader = ThreadDataLoader(
            train_dataset,
            shuffle=True,
            batch_size=cfg.batch_size,
            num_workers=0,
            drop_last=True,
        )
        return train_dataloader
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        drop_last=True,
    )
    return train_dataloader


def get_val_dataloader(val_dataset, cfg):
    """Build the validation dataloader; use ThreadDataLoader when the data is cached on GPU."""
    if cfg.val_gpu_cache:
        val_dataloader = ThreadDataLoader(
            val_dataset,
            batch_size=cfg.val_batch_size,
            num_workers=0,
        )
        return val_dataloader
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=cfg.val_batch_size,
        num_workers=cfg.num_workers,
    )
    return val_dataloader


def get_train_dataset(cfg):
    """Build the in-memory cached training dataset."""
    train_ds = CacheDataset(
        data=cfg.data_json["train"],
        transform=cfg.train_transforms,
        cache_rate=cfg.train_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    # Alternative: cache transformed samples on disk instead of in memory.
    # train_ds = PersistentDataset(
    #     data=cfg.data_json["train"],
    #     transform=cfg.train_transforms,
    #     cache_dir="cache_data",
    # )
    return train_ds


def get_val_dataset(cfg):
    """Build the in-memory cached validation dataset."""
    val_ds = CacheDataset(
        data=cfg.data_json["val"],
        transform=cfg.val_transforms,
        cache_rate=cfg.val_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    return val_ds


def get_val_org_dataset(cfg):
    """Build the validation dataset using cfg.org_val_transforms."""
    val_ds = CacheDataset(
        data=cfg.data_json["val"],
        transform=cfg.org_val_transforms,
        cache_rate=cfg.val_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    return val_ds


def get_optimizer(model, cfg):
    """Create an Adam optimizer over all model parameters."""
    params = model.parameters()
    optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    return optimizer


def get_scheduler(cfg, optimizer, total_steps):
    """Create a learning-rate scheduler according to cfg.lr_mode."""
    if cfg.lr_mode == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.epochs * (total_steps // cfg.batch_size),
            eta_min=cfg.min_lr,
        )
    elif cfg.lr_mode == "warmup_restart":
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=cfg.restart_epoch * (total_steps // cfg.batch_size),
            T_mult=1,
            eta_min=cfg.min_lr,
        )
    else:
        raise ValueError(f"Unsupported lr_mode: {cfg.lr_mode}")
    return scheduler


def create_checkpoint(model, optimizer, epoch, scheduler=None, scaler=None):
    """Bundle model/optimizer (and optionally scheduler and AMP scaler) state for saving."""
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    if scheduler is not None:
        checkpoint["scheduler"] = scheduler.state_dict()
    if scaler is not None:
        checkpoint["scaler"] = scaler.state_dict()
    return checkpoint
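

# Illustrative end-to-end wiring of these helpers (a sketch only): `cfg` is the
# same config object the functions above expect (data_json, train_transforms,
# batch_size, num_workers, lr, weight_decay, lr_mode, epochs, min_lr, ...);
# `model`, the seed value, the "image"/"label" batch keys and the checkpoint
# path are assumed for illustration.
#
#   set_determinism(seed=42)
#   train_ds = get_train_dataset(cfg)
#   train_dl = get_train_dataloader(train_ds, cfg)
#   optimizer = get_optimizer(model, cfg)
#   scheduler = get_scheduler(cfg, optimizer, total_steps=len(train_ds))
#   loss_fn = DiceBceMultilabelLoss()
#   for epoch in range(cfg.epochs):
#       for batch in train_dl:
#           optimizer.zero_grad()
#           loss = loss_fn(model(batch["image"]), batch["label"])
#           loss.backward()
#           optimizer.step()
#           scheduler.step()
#       checkpoint = create_checkpoint(model, optimizer, epoch, scheduler)
#       torch.save(checkpoint, f"checkpoint_epoch{epoch}.pth")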