Diff of /monai/utils.py [000000] .. [4e96d3]

Switch to side-by-side view

--- a
+++ b/monai/utils.py
@@ -0,0 +1,144 @@
+import torch
+from torch import nn
+from torch.optim import lr_scheduler
+from monai.utils import set_determinism
+from torch import optim
+from monai.data import CacheDataset, DataLoader, ThreadDataLoader, PersistentDataset
+from torch.nn.modules.loss import _Loss
+from monai.utils import LossReduction
+from monai.losses import DiceLoss
+from monai.transforms import LoadImage
+from monai.optimizers.lr_scheduler import WarmupCosineSchedule
+
+
+class DiceBceMultilabelLoss(_Loss):
+    def __init__(
+        self,
+        w_dice = 0.5,
+        w_bce = 0.5,
+        reduction = LossReduction.MEAN,
+    ):
+        super().__init__(reduction=LossReduction(reduction).value)
+        self.w_dice = w_dice
+        self.w_bce = w_bce
+        self.dice_loss = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, include_background=True, batch=True, squared_pred=True)
+        self.bce_loss = nn.BCEWithLogitsLoss()
+    def forward(self, pred, label):
+        
+        loss = self.dice_loss(pred, label) * self.w_dice + self.bce_loss(pred, label) * self.w_bce
+        return loss
+
+
+def get_train_dataloader(train_dataset, cfg):
+
+    if cfg.gpu_cache:
+        train_dataloader = ThreadDataLoader(
+            train_dataset,
+            shuffle=True,
+            batch_size=cfg.batch_size,
+            num_workers=0,
+            drop_last=True,
+        )
+        return train_dataloader 
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=True,
+        batch_size=cfg.batch_size,
+        num_workers=cfg.num_workers,
+        drop_last=True,
+    )
+    return train_dataloader
+
+
+def get_val_dataloader(val_dataset, cfg):
+    if cfg.val_gpu_cache:
+        val_dataloader = ThreadDataLoader(
+            val_dataset,
+            batch_size=cfg.val_batch_size,
+            num_workers=0,
+        )
+        return val_dataloader
+
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=cfg.val_batch_size,
+        num_workers=cfg.num_workers,
+    )
+    return val_dataloader
+
+def get_train_dataset(cfg):
+    train_ds = CacheDataset(
+        data=cfg.data_json["train"],
+        transform=cfg.train_transforms,
+        cache_rate=cfg.train_cache_rate,
+        num_workers=cfg.num_workers,
+        copy_cache=False,
+    )
+    # train_ds = PersistentDataset(
+    #     data=cfg.data_json["train"],
+    #     transform=cfg.train_transforms,
+    #     cache_dir="cache_data",
+    # )
+    return train_ds
+
+def get_val_dataset(cfg):
+    val_ds = CacheDataset(
+        data=cfg.data_json["val"],
+        transform=cfg.val_transforms,
+        cache_rate=cfg.val_cache_rate,
+        num_workers=cfg.num_workers,
+        copy_cache=False,
+    )
+    return val_ds
+
+def get_val_org_dataset(cfg):
+    val_ds = CacheDataset(
+        data=cfg.data_json["val"],
+        transform=cfg.org_val_transforms,
+        cache_rate=cfg.val_cache_rate,
+        num_workers=cfg.num_workers,
+        copy_cache=False,
+    )
+    return val_ds
+
+def get_optimizer(model, cfg):
+
+    params = model.parameters()
+    optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
+
+    return optimizer
+
+def get_scheduler(cfg, optimizer, total_steps):
+
+    if cfg.lr_mode == "cosine":
+        scheduler = lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            T_max=cfg.epochs * (total_steps // cfg.batch_size),
+            eta_min=cfg.min_lr,
+        )
+
+    elif cfg.lr_mode == "warmup_restart":
+        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer,
+            T_0=cfg.restart_epoch * (total_steps // cfg.batch_size),
+            T_mult=1,
+            eta_min=cfg.min_lr,
+        )
+
+    return scheduler
+
+
+def create_checkpoint(model, optimizer, epoch, scheduler=None, scaler=None):
+    checkpoint = {
+        "model": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+        "epoch": epoch,
+    }
+
+    if scheduler is not None:
+        checkpoint["scheduler"] = scheduler.state_dict()
+
+    if scaler is not None:
+        checkpoint["scaler"] = scaler.state_dict()
+    return checkpoint