--- /dev/null
+++ b/monai/multilabel_train.py
@@ -0,0 +1,315 @@
+import argparse
+import gc
+import importlib
+import json
+import os
+import shutil
+import sys
+
+import torch
+from torch import nn
+from torch.cuda.amp import GradScaler, autocast
+from tqdm import tqdm
+
+from monai.inferers import sliding_window_inference
+
+from utils import *
+
+from monai.transforms import (
+    Compose,
+    Activations,
+    AsDiscrete,
+)
+from metric import HausdorffScore
+from monai.utils import set_determinism  # used by the commented-out seeding below
+from monai.losses import DiceLoss  # kept as a drop-in alternative loss (commented below)
+from monai.networks.nets import SegResNet
+from monai.optimizers import Novograd  # optional alternative optimizer (commented below)
+from monai.metrics import DiceMetric
+
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.benchmark = True
+
+
+def main(cfg):
+
+    # resolve the dataset split json for this fold (-1 means train on all data)
+    if cfg.fold != -1:
+        cfg.data_json_dir = cfg.data_dir + f"dataset_3d_fold_{cfg.fold}.json"
+        fold_dir = f"fold{cfg.fold}"
+    else:
+        cfg.data_json_dir = cfg.data_dir + "dataset_3d_all.json"
+        fold_dir = "all"
+
+    with open(cfg.data_json_dir, "r") as f:
+        cfg.data_json = json.load(f)
+
+    os.makedirs(f"{cfg.output_dir}/{fold_dir}/", exist_ok=True)
+
+    # # set random seed
+    # set_determinism(cfg.seed)
+
+    train_dataset = get_train_dataset(cfg)
+    train_dataloader = get_train_dataloader(train_dataset, cfg)
+
+    val_dataset = get_val_dataset(cfg)
+    val_dataloader = get_val_dataloader(val_dataset, cfg)
+
+    print(f"run fold {cfg.fold}, train len: {len(train_dataset)}")
+
+    if cfg.model_type.startswith("segres"):
+        model = SegResNet(
+            spatial_dims=3,
+            in_channels=1,
+            out_channels=3,
+            init_filters=int(cfg.model_type.replace("segres", "")),
+            norm="BATCH",
+            act="PRELU",
+        ).to(cfg.device)
+    else:
+        raise ValueError(f"unsupported model_type: {cfg.model_type}")
+
+    print(cfg.weights)
+    if cfg.weights is not None:
+        stt = torch.load(cfg.weights, map_location="cpu")
+        if "model" in stt:
+            stt = stt["model"]
+        if "state_dict" in stt:
+            stt = stt["state_dict"]
+        # drop the output head so checkpoints trained with a different
+        # number of classes can still be loaded
+        stt.pop("out.conv.conv.weight", None)
+        stt.pop("out.conv.conv.bias", None)
+        model.load_state_dict(stt, strict=False)
+        print(f"weights from: {cfg.weights} are loaded.")
+
+    # set optimizer and lr scheduler (passed to get_scheduler to size the schedule)
+    total_steps = len(train_dataset)
+    optimizer = get_optimizer(model, cfg)
+    # optimizer = Novograd(model.parameters(), cfg.lr)
+    scheduler = get_scheduler(cfg, optimizer, total_steps)
+
+    seg_loss_func = DiceBceMultilabelLoss(w_dice=cfg.w_dice, w_bce=1 - cfg.w_dice)
+    # seg_loss_func = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, include_background=True, batch=True)
+    dice_metric = DiceMetric(reduction="mean")
+    hausdorff_metric = HausdorffScore(reduction="mean")
+    metric_function = [dice_metric, hausdorff_metric]
+
+    # binarize sigmoid probabilities at 0.5 for metric computation
+    post_pred = Compose([
+        Activations(sigmoid=True),
+        AsDiscrete(threshold=0.5),
+    ])
+
+    # train and val loop
+    step = 0
+    i = 0
+    if cfg.eval:
+        best_val_metric = run_eval(
+            model=model,
+            val_dataloader=val_dataloader,
+            post_pred=post_pred,
+            metric_function=metric_function,
+            seg_loss_func=seg_loss_func,
+            cfg=cfg,
+            epoch=0,
+        )
+    else:
+        best_val_metric = 0.0
+    best_weights_name = "best_weights"
+    for epoch in range(cfg.epochs):
+        print("EPOCH:", epoch)
+        gc.collect()
+        if cfg.train:
+            run_train(
+                model=model,
+                train_dataloader=train_dataloader,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                seg_loss_func=seg_loss_func,
+                cfg=cfg,
+                # writer=writer,
+                epoch=epoch,
+                step=step,
+                iteration=i,
+            )
+
+        if (epoch + 1) % cfg.eval_epochs == 0 and cfg.eval and epoch > cfg.start_eval_epoch:
+            val_metric = run_eval(
+                model=model,
+                val_dataloader=val_dataloader,
+                post_pred=post_pred,
+                metric_function=metric_function,
+                seg_loss_func=seg_loss_func,
+                cfg=cfg,
+                epoch=epoch,
+            )
+
+            if val_metric > best_val_metric:
+                print(f"Found better metric: val_metric {best_val_metric:.5} -> {val_metric:.5}")
+                best_val_metric = val_metric
+                checkpoint = create_checkpoint(
+                    model,
+                    optimizer,
+                    epoch,
+                    scheduler=scheduler,
+                )
+                torch.save(
+                    checkpoint,
+                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth",
+                )
+            elif cfg.load_best_weights:
+                # roll back to the best checkpoint when the metric regresses
+                best_path = f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth"
+                if os.path.exists(best_path):
+                    model.load_state_dict(torch.load(best_path)["model"])
+                    print(f"metric did not improve, loaded the saved best weights with score: {best_val_metric}.")
+
+        if (epoch + 1) == cfg.epochs:
+            # save the final best weights under a distinct name to avoid mix-ups
+            if os.path.exists(f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth"):
+                shutil.copyfile(
+                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth",
+                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}_{best_val_metric:.4f}.pth",
+                )
+
+    torch.save(
+        model.state_dict(),
+        f"{cfg.output_dir}/{fold_dir}/last.pth",
+    )
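+
+
+# NOTE (assumption): `DiceBceMultilabelLoss` is expected to come from `utils`
+# via the star import above; it is not shown in this diff. The guarded
+# fallback below is only a minimal sketch of the assumed interface -- sigmoid
+# Dice plus BCE-with-logits, weighted by `w_dice` / `w_bce` -- and is defined
+# only when `utils` does not already provide the class.
+if "DiceBceMultilabelLoss" not in globals():
+    class DiceBceMultilabelLoss(nn.Module):
+        def __init__(self, w_dice=0.5, w_bce=0.5):
+            super().__init__()
+            self.w_dice = w_dice
+            self.w_bce = w_bce
+            self.dice = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, batch=True)
+            self.bce = nn.BCEWithLogitsLoss()
+
+        def forward(self, pred, target):
+            # both terms take raw logits; the two weights are expected to sum to 1
+            return self.w_dice * self.dice(pred, target) + self.w_bce * self.bce(pred, target.float())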
+
+
+def run_train(
+    model,
+    train_dataloader,
+    optimizer,
+    scheduler,
+    seg_loss_func,
+    cfg,
+    # writer,
+    epoch,
+    step,
+    iteration,
+):
+    model.train()
+    scaler = GradScaler()
+    progress_bar = tqdm(range(len(train_dataloader)))
+    tr_it = iter(train_dataloader)
+    dataset_size = 0
+    running_loss = 0.0
+
+    for itr in progress_bar:
+        iteration += 1
+        batch = next(tr_it)
+        inputs, masks = (
+            batch["image"].to(cfg.device),
+            batch["mask"].to(cfg.device),
+        )
+
+        step += cfg.batch_size
+
+        if cfg.amp:
+            with autocast():
+                outputs = model(inputs)
+                loss = seg_loss_func(outputs, masks)
+            scaler.scale(loss).backward()
+            # unscale before clipping so the norm is computed on the true gradients
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            outputs = model(inputs)
+            loss = seg_loss_func(outputs, masks)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
+            optimizer.step()
+
+        optimizer.zero_grad()
+        scheduler.step()
+
+        # running mean of the per-sample loss for the progress bar
+        running_loss += loss.item() * cfg.batch_size
+        dataset_size += cfg.batch_size
+        losses = running_loss / dataset_size
+        progress_bar.set_description(f"loss: {losses:.4f} lr: {optimizer.param_groups[0]['lr']:.6f}")
+        del batch, inputs, masks, outputs, loss
+    print(f"Train loss: {losses:.4f}")
+    torch.cuda.empty_cache()
+
+
+def run_eval(model, val_dataloader, post_pred, metric_function, seg_loss_func, cfg, epoch):
+    model.eval()
+
+    dice_metric, hausdorff_metric = metric_function
+
+    progress_bar = tqdm(range(len(val_dataloader)))
+    val_it = iter(val_dataloader)
+    with torch.no_grad():
+        for itr in progress_bar:
+            batch = next(val_it)
+            val_inputs, val_masks = (
+                batch["image"].to(cfg.device),
+                batch["mask"].to(cfg.device),
+            )
+            with autocast(enabled=cfg.val_amp):
+                val_outputs = sliding_window_inference(val_inputs, cfg.roi_size, cfg.sw_batch_size, model)
+
+                # test-time augmentation: average logits over flipped views
+                if cfg.run_tta_val:
+                    tta_ct = 1
+                    for dims in [[2], [3], [2, 3]]:
+                        flip_val_outputs = sliding_window_inference(
+                            torch.flip(val_inputs, dims=dims), cfg.roi_size, cfg.sw_batch_size, model
+                        )
+                        val_outputs += torch.flip(flip_val_outputs, dims=dims)
+                        tta_ct += 1
+                    val_outputs /= tta_ct
+
+            val_outputs = torch.stack([post_pred(i) for i in val_outputs])
+            # metrics are slice level: (n, c, h, w, d) -> (n, d, c, h, w) -> (n*d, c, h, w)
+            val_outputs = val_outputs.permute([0, 4, 1, 2, 3]).flatten(0, 1)
+            val_masks = val_masks.permute([0, 4, 1, 2, 3]).flatten(0, 1)
+
+            hausdorff_metric(y_pred=val_outputs, y=val_masks)
+            dice_metric(y_pred=val_outputs, y=val_masks)
+
+            del val_outputs, val_inputs, val_masks, batch
+
+    dice_score = dice_metric.aggregate().item()
+    hausdorff_score = hausdorff_metric.aggregate().item()
+    dice_metric.reset()
+    hausdorff_metric.reset()
+
+    # competition-style score: weighted combination of Dice and Hausdorff
+    all_score = dice_score * 0.4 + hausdorff_score * 0.6
+    print(f"dice_score: {dice_score} hausdorff_score: {hausdorff_score} all_score: {all_score}")
+    torch.cuda.empty_cache()
+
+    return all_score
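+
+
+# NOTE (assumption): `create_checkpoint` is also expected from `utils` and is
+# not shown in this diff. The guarded fallback below only sketches the assumed
+# contract: a dict whose "model" entry holds the model state_dict, which is
+# what main() reads back when it restores the best weights.
+if "create_checkpoint" not in globals():
+    def create_checkpoint(model, optimizer, epoch, scheduler=None):
+        checkpoint = {
+            "model": model.state_dict(),
+            "optimizer": optimizer.state_dict(),
+            "epoch": epoch,
+        }
+        if scheduler is not None:
+            checkpoint["scheduler"] = scheduler.state_dict()
+        return checkpoint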
+
+
+if __name__ == "__main__":
+
+    sys.path.append("configs")
+
+    parser = argparse.ArgumentParser(description="")
+
+    parser.add_argument("-c", "--config", default="cfg_unet_multilabel", help="config filename")
+    parser.add_argument("-f", "--fold", type=int, default=0, help="fold")
+    parser.add_argument("-s", "--seed", type=int, default=20220421, help="seed")
+    parser.add_argument("-w", "--weights", default=None, help="path to pretrained weights")
+
+    # parse_known_args() defaults to sys.argv[1:]; passing sys.argv would treat
+    # the script name itself as an argument
+    parser_args, _ = parser.parse_known_args()
+
+    cfg = importlib.import_module(parser_args.config).cfg
+    cfg.fold = parser_args.fold
+    cfg.seed = parser_args.seed
+    cfg.weights = parser_args.weights
+
+    main(cfg)
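
Note: the config module loaded at startup (e.g. configs/cfg_unet_multilabel.py)
is not part of this diff. Below is a minimal sketch of the `cfg` object the
script assumes; the attribute names come from the usages above, while the
values are purely illustrative defaults:

    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        data_dir="./data/",
        output_dir="./output/",
        device="cuda" if torch.cuda.is_available() else "cpu",
        model_type="segres32",  # SegResNet with init_filters=32
        weights=None,
        train=True,
        eval=True,
        eval_epochs=1,
        start_eval_epoch=0,
        load_best_weights=True,
        epochs=100,
        batch_size=2,
        lr=1e-3,
        w_dice=0.5,  # weight of the Dice term in DiceBceMultilabelLoss
        amp=True,
        val_amp=True,
        run_tta_val=True,
        roi_size=(192, 192, 80),
        sw_batch_size=4,
    )

The dataset, dataloader, optimizer, and scheduler helpers (get_train_dataset,
get_val_dataloader, get_optimizer, get_scheduler, ...) come from `utils` and
are assumed to consume these fields.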