--- a +++ b/monai/configs/segres20_all_round2.py @@ -0,0 +1,159 @@ +import numpy as np +import torch +from monai.transforms import ( + Compose, + LoadImaged, + RandSpatialCropd, + EnsureTyped, + CastToTyped, + NormalizeIntensityd, + RandFlipd, + CenterSpatialCropd, + ScaleIntensityRanged, + RandAffined, + RandScaleIntensityd, + RandShiftIntensityd, + RandCoarseDropoutd, + Rand2DElasticd, + Lambdad, + Resized, + AddChanneld, + RandGaussianNoised, + RandGridDistortiond, + RepeatChanneld, + Transposed, + OneOf, + EnsureChannelFirstd, + RandLambdad, + Spacingd, + FgBgToIndicesd, + CropForegroundd, + RandCropByPosNegLabeld, + ToDeviced, + SpatialPadd, + +) + +from default_config import basic_cfg + +cfg = basic_cfg + +# train +cfg.train = True +cfg.eval = False +cfg.start_eval_epoch = 0 # when use large lr, can set a large num +cfg.run_org_eval = False +cfg.run_tta_val = False +cfg.load_best_weights = False +cfg.amp = False +cfg.val_amp = False +cfg.num_workers = 8 +cfg.model_type = "segres20" + +# device +cfg.gpu = 0 +cfg.device = "cuda:%d" % cfg.gpu + +# lr +# warmup_restart, cosine +cfg.lr_mode = "cosine" +cfg.lr = 1e-4 +cfg.min_lr = 1e-6 +cfg.weight_decay = 1e-6 +cfg.epochs = 20 +cfg.restart_epoch = 2 # only for warmup_restart +cfg.eval_epochs = 1 + +cfg.finetune_lb = -1 + +# dataset +cfg.img_size = (224, 224, 80) +cfg.spacing = (1.5, 1.5, 1.5) +cfg.batch_size = 4 +cfg.val_batch_size = 1 +cfg.train_cache_rate = 0.0 +cfg.val_cache_rate = 0.0 +cfg.gpu_cache = False +cfg.val_gpu_cache = False + +# val +cfg.roi_size = (224, 224, 80) +cfg.sw_batch_size = 4 + +# model + +# loss +cfg.w_dice = 1.0 + +cfg.output_dir = "./output/segres20_all_round2" + +# transforms +cfg.train_transforms = Compose( + [ + LoadImaged(keys=["image", "mask"]), + EnsureChannelFirstd(keys=["image", "mask"]), + # Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")), + RandSpatialCropd( + keys=("image", "mask"), + roi_size=cfg.img_size, + random_size=False, + ), + + Lambdad(keys="image", func=lambda x: x / x.max()), + + RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[0]), + RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[1]), + # RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[2]), + RandAffined( + keys=("image", "mask"), + prob=0.5, + rotate_range=np.pi / 12, + translate_range=(cfg.img_size[0]*0.0625, cfg.img_size[1]*0.0625), + scale_range=(0.1, 0.1), + mode="nearest", + padding_mode="reflection", + ), + OneOf( + [ + RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.05, 0.05), mode="nearest", padding_mode="reflection"), + RandCoarseDropoutd( + keys=("image", "mask"), + holes=5, + max_holes=8, + spatial_size=(1, 1, 1), + max_spatial_size=(12, 12, 12), + fill_value=0.0, + prob=0.5, + ), + ] + ), + RandScaleIntensityd(keys="image", factors=(-0.2, 0.2), prob=0.5), + RandShiftIntensityd(keys="image", offsets=(-0.1, 0.1), prob=0.5), + EnsureTyped(keys=("image", "mask"), dtype=torch.float32), + ] +) + +cfg.val_transforms = Compose( + [ + LoadImaged(keys=["image", "mask"]), + EnsureChannelFirstd(keys=["image", "mask"]), + # Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")), + Lambdad(keys="image", func=lambda x: x / x.max()), + + EnsureTyped(keys=("image", "mask"), dtype=torch.float32), + # ToDeviced(keys=["image", "mask"], device="cuda:0"), + ] +) + +cfg.org_val_transforms = Compose( + [ + LoadImaged(keys="image"), + EnsureChannelFirstd(keys="image"), + # Spacingd(keys="image", pixdim=cfg.spacing, mode="bilinear"), + Lambdad(keys="image", func=lambda x: x / x.max()), + # SpatialPadd(keys="image", spatial_size=cfg.img_size), + EnsureTyped(keys="image", dtype=torch.float32), + ] +) + +