|
a |
|
b/monai/configs/segres12_all_round2.py |
|
|
1 |
import numpy as np |
|
|
2 |
import torch |
|
|
3 |
from monai.transforms import ( |
|
|
4 |
Compose, |
|
|
5 |
LoadImaged, |
|
|
6 |
RandSpatialCropd, |
|
|
7 |
EnsureTyped, |
|
|
8 |
CastToTyped, |
|
|
9 |
NormalizeIntensityd, |
|
|
10 |
RandFlipd, |
|
|
11 |
CenterSpatialCropd, |
|
|
12 |
ScaleIntensityRanged, |
|
|
13 |
RandAffined, |
|
|
14 |
RandScaleIntensityd, |
|
|
15 |
RandShiftIntensityd, |
|
|
16 |
RandCoarseDropoutd, |
|
|
17 |
Rand2DElasticd, |
|
|
18 |
Lambdad, |
|
|
19 |
Resized, |
|
|
20 |
AddChanneld, |
|
|
21 |
RandGaussianNoised, |
|
|
22 |
RandGridDistortiond, |
|
|
23 |
RepeatChanneld, |
|
|
24 |
Transposed, |
|
|
25 |
OneOf, |
|
|
26 |
EnsureChannelFirstd, |
|
|
27 |
RandLambdad, |
|
|
28 |
Spacingd, |
|
|
29 |
FgBgToIndicesd, |
|
|
30 |
CropForegroundd, |
|
|
31 |
RandCropByPosNegLabeld, |
|
|
32 |
ToDeviced, |
|
|
33 |
SpatialPadd, |
|
|
34 |
|
|
|
35 |
) |
|
|
36 |
|
|
|
37 |
from default_config import basic_cfg |
|
|
38 |
|
|
|
39 |
cfg = basic_cfg |
|
|
40 |
|
|
|
41 |
# train |
|
|
42 |
cfg.train = True |
|
|
43 |
cfg.eval = False |
|
|
44 |
cfg.start_eval_epoch = 0 # when use large lr, can set a large num |
|
|
45 |
cfg.run_org_eval = False |
|
|
46 |
cfg.run_tta_val = False |
|
|
47 |
cfg.load_best_weights = False |
|
|
48 |
cfg.amp = False |
|
|
49 |
cfg.val_amp = False |
|
|
50 |
cfg.num_workers = 8 |
|
|
51 |
cfg.model_type = "segres12" |
|
|
52 |
|
|
|
53 |
# device |
|
|
54 |
cfg.gpu = 0 |
|
|
55 |
cfg.device = "cuda:%d" % cfg.gpu |
|
|
56 |
|
|
|
57 |
# lr |
|
|
58 |
# warmup_restart, cosine |
|
|
59 |
cfg.lr_mode = "cosine" |
|
|
60 |
cfg.lr = 1e-4 |
|
|
61 |
cfg.min_lr = 1e-6 |
|
|
62 |
cfg.weight_decay = 1e-6 |
|
|
63 |
cfg.epochs = 20 |
|
|
64 |
cfg.restart_epoch = 2 # only for warmup_restart |
|
|
65 |
cfg.eval_epochs = 1 |
|
|
66 |
|
|
|
67 |
cfg.finetune_lb = -1 |
|
|
68 |
|
|
|
69 |
# dataset |
|
|
70 |
cfg.img_size = (224, 224, 80) |
|
|
71 |
cfg.spacing = (1.5, 1.5, 1.5) |
|
|
72 |
cfg.batch_size = 4 |
|
|
73 |
cfg.val_batch_size = 1 |
|
|
74 |
cfg.train_cache_rate = 0.0 |
|
|
75 |
cfg.val_cache_rate = 0.0 |
|
|
76 |
cfg.gpu_cache = False |
|
|
77 |
cfg.val_gpu_cache = False |
|
|
78 |
|
|
|
79 |
# val |
|
|
80 |
cfg.roi_size = (224, 224, 80) |
|
|
81 |
cfg.sw_batch_size = 4 |
|
|
82 |
|
|
|
83 |
# model |
|
|
84 |
|
|
|
85 |
# loss |
|
|
86 |
cfg.w_dice = 1.0 |
|
|
87 |
|
|
|
88 |
cfg.output_dir = "./output/segres12_all_round2" |
|
|
89 |
|
|
|
90 |
# transforms |
|
|
91 |
cfg.train_transforms = Compose( |
|
|
92 |
[ |
|
|
93 |
LoadImaged(keys=["image", "mask"]), |
|
|
94 |
EnsureChannelFirstd(keys=["image", "mask"]), |
|
|
95 |
# Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")), |
|
|
96 |
RandSpatialCropd( |
|
|
97 |
keys=("image", "mask"), |
|
|
98 |
roi_size=cfg.img_size, |
|
|
99 |
random_size=False, |
|
|
100 |
), |
|
|
101 |
|
|
|
102 |
Lambdad(keys="image", func=lambda x: x / x.max()), |
|
|
103 |
|
|
|
104 |
RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[0]), |
|
|
105 |
RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[1]), |
|
|
106 |
# RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[2]), |
|
|
107 |
RandAffined( |
|
|
108 |
keys=("image", "mask"), |
|
|
109 |
prob=0.5, |
|
|
110 |
rotate_range=np.pi / 12, |
|
|
111 |
translate_range=(cfg.img_size[0]*0.0625, cfg.img_size[1]*0.0625), |
|
|
112 |
scale_range=(0.1, 0.1), |
|
|
113 |
mode="nearest", |
|
|
114 |
padding_mode="reflection", |
|
|
115 |
), |
|
|
116 |
OneOf( |
|
|
117 |
[ |
|
|
118 |
RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.05, 0.05), mode="nearest", padding_mode="reflection"), |
|
|
119 |
RandCoarseDropoutd( |
|
|
120 |
keys=("image", "mask"), |
|
|
121 |
holes=5, |
|
|
122 |
max_holes=8, |
|
|
123 |
spatial_size=(1, 1, 1), |
|
|
124 |
max_spatial_size=(12, 12, 12), |
|
|
125 |
fill_value=0.0, |
|
|
126 |
prob=0.5, |
|
|
127 |
), |
|
|
128 |
] |
|
|
129 |
), |
|
|
130 |
RandScaleIntensityd(keys="image", factors=(-0.2, 0.2), prob=0.5), |
|
|
131 |
RandShiftIntensityd(keys="image", offsets=(-0.1, 0.1), prob=0.5), |
|
|
132 |
EnsureTyped(keys=("image", "mask"), dtype=torch.float32), |
|
|
133 |
] |
|
|
134 |
) |
|
|
135 |
|
|
|
136 |
cfg.val_transforms = Compose( |
|
|
137 |
[ |
|
|
138 |
LoadImaged(keys=["image", "mask"]), |
|
|
139 |
EnsureChannelFirstd(keys=["image", "mask"]), |
|
|
140 |
# Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")), |
|
|
141 |
Lambdad(keys="image", func=lambda x: x / x.max()), |
|
|
142 |
|
|
|
143 |
EnsureTyped(keys=("image", "mask"), dtype=torch.float32), |
|
|
144 |
# ToDeviced(keys=["image", "mask"], device="cuda:0"), |
|
|
145 |
] |
|
|
146 |
) |
|
|
147 |
|
|
|
148 |
cfg.org_val_transforms = Compose( |
|
|
149 |
[ |
|
|
150 |
LoadImaged(keys="image"), |
|
|
151 |
EnsureChannelFirstd(keys="image"), |
|
|
152 |
# Spacingd(keys="image", pixdim=cfg.spacing, mode="bilinear"), |
|
|
153 |
Lambdad(keys="image", func=lambda x: x / x.max()), |
|
|
154 |
# SpatialPadd(keys="image", spatial_size=cfg.img_size), |
|
|
155 |
EnsureTyped(keys="image", dtype=torch.float32), |
|
|
156 |
] |
|
|
157 |
) |
|
|
158 |
|
|
|
159 |
|