--- /dev/null
+++ b/monai/multilabel_train.py
@@ -0,0 +1,313 @@
+import argparse
+import gc
+import importlib
+import os
+import sys
+import shutil
+
+import torch
+from monai.inferers import sliding_window_inference
+from torch.cuda.amp import GradScaler, autocast
+from tqdm import tqdm
+
+from utils import *
+
+from monai.transforms import (
+    Compose,
+    Activations,
+    AsDiscrete,
+)
+import json
+from metric import HausdorffScore
+from monai.utils import set_determinism
+from monai.losses import DiceLoss, DiceCELoss
+from monai.networks.nets import SegResNet
+from monai.optimizers import Novograd
+from monai.metrics import DiceMetric
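+# cudnn.benchmark autotunes conv algorithms; effective when input shapes are fixed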
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.benchmark = True
+
+
+def main(cfg):
+
+    # resolve the dataset split json for this fold (-1 trains on all folds)
+    if cfg.fold != -1:
+        cfg.data_json_dir = os.path.join(cfg.data_dir, f"dataset_3d_fold_{cfg.fold}.json")
+    else:
+        cfg.data_json_dir = os.path.join(cfg.data_dir, "dataset_3d_all.json")
+
+    with open(cfg.data_json_dir, "r") as f:
+        cfg.data_json = json.load(f)
+
+    if cfg.fold != -1:
+        fold_dir = f"fold{cfg.fold}"
+    else:
+        fold_dir = "all"
+    os.makedirs(os.path.join(cfg.output_dir, fold_dir), exist_ok=True)
+
+    # uncomment to make runs reproducible; cfg.seed is set from the CLI below
+    # set_determinism(cfg.seed)
+
+    train_dataset = get_train_dataset(cfg)
+    train_dataloader = get_train_dataloader(train_dataset, cfg)
+
+    val_dataset = get_val_dataset(cfg)
+    val_dataloader = get_val_dataloader(val_dataset, cfg)
+
+    print(f"run fold {cfg.fold}, train len: {len(train_dataset)}")
+
+    if cfg.model_type.startswith("segres"):
+        # e.g. model_type "segres32" -> SegResNet with 32 initial filters
+        model = SegResNet(
+            spatial_dims=3,
+            in_channels=1,
+            out_channels=3,
+            init_filters=int(cfg.model_type.replace("segres", "")),
+            norm="BATCH",
+            act="PRELU",
+        ).to(cfg.device)
+    else:
+        raise ValueError(f"unsupported model_type: {cfg.model_type}")
+
+    if cfg.weights is not None:
+        stt = torch.load(cfg.weights, map_location="cpu")
+        if "model" in stt:
+            stt = stt["model"]
+        if "state_dict" in stt:
+            stt = stt["state_dict"]
+            # drop the output head so checkpoints trained with a different
+            # number of classes can still warm-start this model
+            stt.pop("out.conv.conv.weight", None)
+            stt.pop("out.conv.conv.bias", None)
+        model.load_state_dict(stt, strict=False)
+        print(f"weights from: {cfg.weights} are loaded.")
+
+    # set optimizer, lr scheduler
+    total_steps = len(train_dataset)
+    optimizer = get_optimizer(model, cfg)
+    # optimizer = Novograd(model.parameters(), cfg.lr)
+    scheduler = get_scheduler(cfg, optimizer, total_steps)
+
+    seg_loss_func = DiceBceMultilabelLoss(w_dice=cfg.w_dice, w_bce=1-cfg.w_dice)
+    # seg_loss_func = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, include_background=True, batch=True)
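+    # validation tracks slice-wise Dice plus a Hausdorff-distance score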
+    dice_metric = DiceMetric(reduction="mean")
+    hausdorff_metric = HausdorffScore(reduction="mean")
+    metric_function = [dice_metric, hausdorff_metric]
+
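+    # post-processing for predictions: sigmoid then 0.5 threshold -> binary multilabel masks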
+    post_pred = Compose([
+        Activations(sigmoid=True),
+        AsDiscrete(threshold=0.5),
+    ])
+
+    # train and val loop
+    step = 0
+    iteration = 0
+    if cfg.eval:
+        best_val_metric = run_eval(
+            model=model,
+            val_dataloader=val_dataloader,
+            post_pred=post_pred,
+            metric_function=metric_function,
+            seg_loss_func=seg_loss_func,
+            cfg=cfg,
+            epoch=0,
+        )
+    else:
+        best_val_metric = 0.0
+    best_weights_name = "best_weights"
+    for epoch in range(cfg.epochs):
+        print("EPOCH:", epoch)
+        gc.collect()
+        if cfg.train:
+            run_train(
+                model=model,
+                train_dataloader=train_dataloader,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                seg_loss_func=seg_loss_func,
+                cfg=cfg,
+                # writer=writer,
+                epoch=epoch,
+                step=step,
+                iteration=iteration,
+            )
+
+        if (epoch + 1) % cfg.eval_epochs == 0 and cfg.eval and epoch > cfg.start_eval_epoch:
+            val_metric = run_eval(
+                model=model,
+                val_dataloader=val_dataloader,
+                post_pred=post_pred,
+                metric_function=metric_function,
+                seg_loss_func=seg_loss_func,
+                cfg=cfg,
+                epoch=epoch,
+            )
+
+            if val_metric > best_val_metric:
+                print(f"Find better metric: val_metric {best_val_metric:.5} -> {val_metric:.5}")
+                best_val_metric = val_metric
+                checkpoint = create_checkpoint(
+                    model,
+                    optimizer,
+                    epoch,
+                    scheduler=scheduler,
+                )
+                torch.save(
+                    checkpoint,
+                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth",
+                )
+            elif cfg.load_best_weights:
+                best_path = f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth"
+                try:
+                    model.load_state_dict(torch.load(best_path, map_location="cpu")["model"])
+                    print(f"metric did not improve, reloaded best weights with score: {best_val_metric}.")
+                except FileNotFoundError:
+                    pass
+
+        if (epoch + 1) == cfg.epochs:
+            # copy the best weights to a score-stamped name so runs are not mixed up
+            if os.path.exists(f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth"):
+                shutil.copyfile(
+                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth",
+                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}_{best_val_metric:.4f}.pth",
+                )
+
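+        # save the most recent weights after every epoch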
+        torch.save(
+            model.state_dict(),
+            f"{cfg.output_dir}/{fold_dir}/last.pth",
+        )
+
+
+def run_train(
+    model,
+    train_dataloader,
+    optimizer,
+    scheduler,
+    seg_loss_func,
+    cfg,
+    # writer,
+    epoch,
+    step,
+    iteration,
+):
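+    # one epoch of training; mixed precision is controlled by cfg.amp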
+    model.train()
+    scaler = GradScaler(enabled=cfg.amp)
+    progress_bar = tqdm(range(len(train_dataloader)))
+    tr_it = iter(train_dataloader)
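+    # accumulate a batch-size-weighted running average of the loss for logging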
+    dataset_size = 0
+    running_loss = 0.0
+
+    for itr in progress_bar:
+        iteration += 1
+        batch = next(tr_it)
+        inputs, masks = (
+            batch["image"].to(cfg.device),
+            batch["mask"].to(cfg.device),
+        )
+
+        step += cfg.batch_size
+
+        with autocast(enabled=cfg.amp):
+            outputs = model(inputs)
+            loss = seg_loss_func(outputs, masks)
+
+        scaler.scale(loss).backward()
+        # unscale gradients before clipping so the threshold applies to the
+        # true gradient values, not the loss-scaled ones
+        scaler.unscale_(optimizer)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
+        scheduler.step()
+
+        running_loss += loss.item() * cfg.batch_size
+        dataset_size += cfg.batch_size
+        avg_loss = running_loss / dataset_size
+        progress_bar.set_description(f"loss: {avg_loss:.4f} lr: {optimizer.param_groups[0]['lr']:.6f}")
+        del batch, inputs, masks, outputs, loss
+    print(f"Train loss: {avg_loss:.4f}")
+    torch.cuda.empty_cache()
+
+def run_eval(model, val_dataloader, post_pred, metric_function, seg_loss_func, cfg, epoch):
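+    # sliding-window validation; returns the combined Dice/Hausdorff score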
+
+    model.eval()
+
+    dice_metric, hausdorff_metric = metric_function
+
+    progress_bar = tqdm(range(len(val_dataloader)))
+    val_it = iter(val_dataloader)
+    with torch.no_grad():
+        for itr in progress_bar:
+            batch = next(val_it)
+            val_inputs, val_masks = (
+                batch["image"].to(cfg.device),
+                batch["mask"].to(cfg.device),
+            )
+            with autocast(enabled=cfg.val_amp):
+                val_outputs = sliding_window_inference(val_inputs, cfg.roi_size, cfg.sw_batch_size, model)
+
+                # test-time augmentation: average logits over flips along the
+                # spatial axes (dims 2 and 3) and their combination
+                if cfg.run_tta_val:
+                    tta_ct = 1
+                    for dims in [[2], [3], [2, 3]]:
+                        flip_val_outputs = sliding_window_inference(
+                            torch.flip(val_inputs, dims=dims), cfg.roi_size, cfg.sw_batch_size, model
+                        )
+                        val_outputs += torch.flip(flip_val_outputs, dims=dims)
+                        tta_ct += 1
+                    val_outputs /= tta_ct
+
+            val_outputs = torch.stack([post_pred(pred) for pred in val_outputs])
+            # metrics are computed per 2D slice: permute (n, c, h, w, d)
+            # to (n, d, c, h, w), then flatten to (n*d, c, h, w)
+            val_outputs = val_outputs.permute([0, 4, 1, 2, 3]).flatten(0, 1)
+            val_masks = val_masks.permute([0, 4, 1, 2, 3]).flatten(0, 1)
+
+            hausdorff_metric(y_pred=val_outputs, y=val_masks)
+            dice_metric(y_pred=val_outputs, y=val_masks)
+
+            del val_outputs, val_inputs, val_masks, batch
+
+    dice_score = dice_metric.aggregate().item()
+    hausdorff_score = hausdorff_metric.aggregate().item()
+    dice_metric.reset()
+    hausdorff_metric.reset()
+
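+    # composite score: Dice weighted 0.4, Hausdorff weighted 0.6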
+    all_score = dice_score * 0.4 + hausdorff_score * 0.6
+    print(f"dice_score: {dice_score} hausdorff_score: {hausdorff_score} all_score: {all_score}")
+    torch.cuda.empty_cache()
+
+    return all_score
+
+
+if __name__ == "__main__":
+
+    sys.path.append("configs")
+
+    parser = argparse.ArgumentParser(description="train multilabel 3D segmentation models")
+
+    parser.add_argument("-c", "--config", default="cfg_unet_multilabel", help="config filename")
+    parser.add_argument("-f", "--fold", type=int, default=0, help="fold")
+    parser.add_argument("-s", "--seed", type=int, default=20220421, help="seed")
+    parser.add_argument("-w", "--weights", default=None, help="the path of weights")
+
+    parser_args, _ = parser.parse_known_args()
+
+    cfg = importlib.import_module(parser_args.config).cfg
+    cfg.fold = parser_args.fold
+    cfg.seed = parser_args.seed
+    cfg.weights = parser_args.weights
+
+    main(cfg)