b/monai/utils.py

import torch
from torch import nn, optim
from torch.nn.modules.loss import _Loss
from torch.optim import lr_scheduler

from monai.data import CacheDataset, DataLoader, PersistentDataset, ThreadDataLoader
from monai.losses import DiceLoss
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.transforms import LoadImage
from monai.utils import LossReduction, set_determinism


class DiceBceMultilabelLoss(_Loss):
    """Weighted sum of Dice loss and BCE-with-logits loss for multilabel segmentation."""

    def __init__(
        self,
        w_dice=0.5,
        w_bce=0.5,
        reduction=LossReduction.MEAN,
    ):
        super().__init__(reduction=LossReduction(reduction).value)
        self.w_dice = w_dice
        self.w_bce = w_bce
        # Both component losses expect raw logits: DiceLoss applies the sigmoid
        # itself (sigmoid=True) and BCEWithLogitsLoss does so internally.
        self.dice_loss = DiceLoss(
            sigmoid=True,
            smooth_nr=0.01,
            smooth_dr=0.01,
            include_background=True,
            batch=True,
            squared_pred=True,
        )
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, label):
        loss = self.dice_loss(pred, label) * self.w_dice + self.bce_loss(pred, label) * self.w_bce
        return loss
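
# Illustrative usage sketch (not taken from the training pipeline; the tensor shapes
# are assumptions for a 3-channel multilabel problem). `pred` must be raw logits,
# since both component losses apply the sigmoid themselves:
#
#     criterion = DiceBceMultilabelLoss(w_dice=0.5, w_bce=0.5)
#     pred = torch.randn(2, 3, 256, 256)                      # model logits
#     label = torch.randint(0, 2, (2, 3, 256, 256)).float()   # binary masks
#     loss = criterion(pred, label)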

def get_train_dataloader(train_dataset, cfg):
    # With data cached on the GPU, use MONAI's ThreadDataLoader with num_workers=0:
    # batches are prepared in the main process (buffered by a thread), so no CUDA
    # tensors have to cross process boundaries.
    if cfg.gpu_cache:
        train_dataloader = ThreadDataLoader(
            train_dataset,
            shuffle=True,
            batch_size=cfg.batch_size,
            num_workers=0,
            drop_last=True,
        )
        return train_dataloader

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        drop_last=True,
    )
    return train_dataloader

def get_val_dataloader(val_dataset, cfg):
    if cfg.val_gpu_cache:
        val_dataloader = ThreadDataLoader(
            val_dataset,
            batch_size=cfg.val_batch_size,
            num_workers=0,
        )
        return val_dataloader

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=cfg.val_batch_size,
        num_workers=cfg.num_workers,
    )
    return val_dataloader

def get_train_dataset(cfg):
    # Cache the preprocessed training samples in RAM; copy_cache=False skips the
    # deep copy of each cached item at fetch time.
    train_ds = CacheDataset(
        data=cfg.data_json["train"],
        transform=cfg.train_transforms,
        cache_rate=cfg.train_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    # Alternative kept for reference: cache to disk instead of RAM.
    # train_ds = PersistentDataset(
    #     data=cfg.data_json["train"],
    #     transform=cfg.train_transforms,
    #     cache_dir="cache_data",
    # )
    return train_ds

def get_val_dataset(cfg):
    val_ds = CacheDataset(
        data=cfg.data_json["val"],
        transform=cfg.val_transforms,
        cache_rate=cfg.val_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    return val_ds

def get_val_org_dataset(cfg):
    # Same validation split, but with `org_val_transforms` (presumably the pipeline
    # that keeps images in their original space, e.g. for metric computation).
    val_ds = CacheDataset(
        data=cfg.data_json["val"],
        transform=cfg.org_val_transforms,
        cache_rate=cfg.val_cache_rate,
        num_workers=cfg.num_workers,
        copy_cache=False,
    )
    return val_ds
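
# How the dataset and dataloader factories above are meant to fit together (sketch
# only; `cfg` is assumed to be a config object exposing the attributes used in this
# module):
#
#     train_dl = get_train_dataloader(get_train_dataset(cfg), cfg)
#     val_dl = get_val_dataloader(get_val_dataset(cfg), cfg)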

def get_optimizer(model, cfg):
    params = model.parameters()
    optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    return optimizer

def get_scheduler(cfg, optimizer, total_steps):
    # The schedule horizon is expressed in iterations: `total_steps` (apparently the
    # number of training samples per epoch) divided by the batch size gives the
    # iterations per epoch, which implies the scheduler is stepped once per batch.
    if cfg.lr_mode == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.epochs * (total_steps // cfg.batch_size),
            eta_min=cfg.min_lr,
        )
    elif cfg.lr_mode == "warmup_restart":
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=cfg.restart_epoch * (total_steps // cfg.batch_size),
            T_mult=1,
            eta_min=cfg.min_lr,
        )
    else:
        # Fail loudly instead of hitting a NameError on the return below.
        raise ValueError(f"Unsupported lr_mode: {cfg.lr_mode}")

    return scheduler
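
# Sketch of how the scheduler would be driven, matching the per-iteration T_max
# computation above (illustrative only; the loop variables are assumptions, not part
# of this module):
#
#     scheduler = get_scheduler(cfg, optimizer, total_steps=len(train_ds))
#     for epoch in range(cfg.epochs):
#         for batch in train_dl:
#             ...  # forward pass, loss.backward(), optimizer.step()
#             scheduler.step()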

def create_checkpoint(model, optimizer, epoch, scheduler=None, scaler=None):
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }

    if scheduler is not None:
        checkpoint["scheduler"] = scheduler.state_dict()

    if scaler is not None:
        checkpoint["scaler"] = scaler.state_dict()
    return checkpoint
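
# Example of persisting and restoring the checkpoint (sketch; the file name and the
# use of an AMP GradScaler are assumptions):
#
#     ckpt = create_checkpoint(model, optimizer, epoch, scheduler=scheduler, scaler=scaler)
#     torch.save(ckpt, "checkpoint.pth")
#
#     state = torch.load("checkpoint.pth", map_location="cpu")
#     model.load_state_dict(state["model"])
#     optimizer.load_state_dict(state["optimizer"])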