Switch to unified view

a b/monai/configs/segres12_all.py
1
import numpy as np
2
import torch
3
from monai.transforms import (
4
    Compose,
5
    LoadImaged,
6
    RandSpatialCropd,
7
    EnsureTyped,
8
    CastToTyped,
9
    NormalizeIntensityd,
10
    RandFlipd,
11
    CenterSpatialCropd,
12
    ScaleIntensityRanged,
13
    RandAffined,
14
    RandScaleIntensityd,
15
    RandShiftIntensityd,
16
    RandCoarseDropoutd,
17
    Rand2DElasticd,
18
    Lambdad,
19
    Resized,
20
    AddChanneld,
21
    RandGaussianNoised,
22
    RandGridDistortiond,
23
    RepeatChanneld,
24
    Transposed,
25
    OneOf,
26
    EnsureChannelFirstd,
27
    RandLambdad,
28
    Spacingd,
29
    FgBgToIndicesd,
30
    CropForegroundd,
31
    RandCropByPosNegLabeld,
32
    ToDeviced,
33
    SpatialPadd,
34
35
)
36
37
from default_config import basic_cfg
38
39
cfg = basic_cfg
40
41
# train
42
cfg.train = True
43
cfg.eval = False
44
cfg.start_eval_epoch = 5  # when use large lr, can set a large num
45
cfg.run_org_eval = False
46
cfg.run_tta_val = False
47
cfg.load_best_weights = False
48
cfg.amp = False
49
cfg.val_amp = False
50
cfg.num_workers = 8
51
cfg.model_type = "segres12"
52
53
# device
54
cfg.gpu = 0
55
cfg.device = "cuda:%d" % cfg.gpu
56
57
# lr
58
# warmup_restart, cosine
59
cfg.lr_mode = "warmup_restart"
60
cfg.lr = 5e-4
61
cfg.min_lr = 2e-4
62
cfg.weight_decay = 1e-6
63
cfg.epochs = 1000
64
cfg.restart_epoch = 100  # only for warmup_restart
65
cfg.eval_epochs = 10
66
67
cfg.finetune_lb = -1
68
69
# dataset
70
cfg.img_size = (160, 160, 80)
71
cfg.spacing = (1.5, 1.5, 1.5)
72
cfg.batch_size = 4
73
cfg.val_batch_size = 1
74
cfg.train_cache_rate = 0.0
75
cfg.val_cache_rate = 0.0
76
cfg.gpu_cache = False
77
cfg.val_gpu_cache = False
78
79
# val
80
cfg.roi_size = (224, 224, 80)
81
cfg.sw_batch_size = 4
82
83
# model
84
85
# loss
86
cfg.w_dice = 1.0
87
88
cfg.output_dir = "./output/segres12_all"
89
        
90
# transforms
91
cfg.train_transforms = Compose(
92
    [
93
        LoadImaged(keys=["image", "mask"]),
94
        EnsureChannelFirstd(keys=["image", "mask"]),
95
        # Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")),
96
        RandSpatialCropd(
97
            keys=("image", "mask"),
98
            roi_size=cfg.img_size,
99
            random_size=False,
100
        ),
101
       
102
        Lambdad(keys="image", func=lambda x: x / x.max()),
103
104
        RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[0]),
105
        RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[1]),
106
        # RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[2]),
107
        RandAffined(
108
            keys=("image", "mask"),
109
            prob=0.5,
110
            rotate_range=np.pi / 12,
111
            translate_range=(cfg.img_size[0]*0.0625, cfg.img_size[1]*0.0625),
112
            scale_range=(0.1, 0.1),
113
            mode="nearest",
114
            padding_mode="reflection",
115
        ),
116
        OneOf(
117
            [
118
                RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.05, 0.05), mode="nearest", padding_mode="reflection"),
119
                RandCoarseDropoutd(
120
                    keys=("image", "mask"),
121
                    holes=5,
122
                    max_holes=8,
123
                    spatial_size=(1, 1, 1),
124
                    max_spatial_size=(12, 12, 12),
125
                    fill_value=0.0,
126
                    prob=0.5,
127
                ),
128
            ]
129
        ),
130
        RandScaleIntensityd(keys="image", factors=(-0.2, 0.2), prob=0.5),
131
        RandShiftIntensityd(keys="image", offsets=(-0.1, 0.1), prob=0.5),
132
        EnsureTyped(keys=("image", "mask"), dtype=torch.float32),
133
    ]
134
)
135
136
cfg.val_transforms = Compose(
137
    [
138
        LoadImaged(keys=["image", "mask"]),
139
        EnsureChannelFirstd(keys=["image", "mask"]),
140
        # Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")),
141
        Lambdad(keys="image", func=lambda x: x / x.max()),
142
       
143
        EnsureTyped(keys=("image", "mask"), dtype=torch.float32),
144
        # ToDeviced(keys=["image", "mask"], device="cuda:0"),
145
    ]
146
)
147
148
cfg.org_val_transforms = Compose(
149
    [
150
        LoadImaged(keys="image"),
151
        EnsureChannelFirstd(keys="image"),
152
        # Spacingd(keys="image", pixdim=cfg.spacing, mode="bilinear"),
153
        Lambdad(keys="image", func=lambda x: x / x.max()),
154
        # SpatialPadd(keys="image", spatial_size=cfg.img_size),
155
        EnsureTyped(keys="image", dtype=torch.float32),
156
    ]
157
)
158
159