segmentation/initialize_train.py

'''
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
'''
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    DeleteItemsd,
    Spacingd,
    RandAffined,
    ConcatItemsd,
    ScaleIntensityRanged,
    ResizeWithPadOrCropd,
    Invertd,
    AsDiscreted,
    SaveImaged,
)
from monai.networks.nets import UNet, SegResNet, DynUNet, SwinUNETR, UNETR, AttentionUnet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
import torch
import matplotlib.pyplot as plt
from glob import glob
import pandas as pd
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import sys

# Make the repository root importable so that config.py can be found.
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
sys.path.append(config_dir)
from config import DATA_FOLDER, WORKING_FOLDER
#%%
def convert_to_4digits(str_num):
    """Zero-pad a numeric string to 4 digits, e.g. '42' -> '0042' (equivalent to str_num.zfill(4))."""
    if len(str_num) == 1:
        new_num = '000' + str_num
    elif len(str_num) == 2:
        new_num = '00' + str_num
    elif len(str_num) == 3:
        new_num = '0' + str_num
    else:
        new_num = str_num
    return new_num

def create_dictionary_ctptgt(ctpaths, ptpaths, gtpaths):
    """Pair up CT, PET, and ground-truth label paths into MONAI-style data dictionaries."""
    data = []
    for i in range(len(gtpaths)):
        ctpath = ctpaths[i]
        ptpath = ptpaths[i]
        gtpath = gtpaths[i]
        data.append({'CT': ctpath, 'PT': ptpath, 'GT': gtpath})
    return data
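
# A minimal sketch of the returned structure (illustrative paths, not real files):
# create_dictionary_ctptgt(['p1_ct.nii.gz'], ['p1_pt.nii.gz'], ['p1_gt.nii.gz'])
# -> [{'CT': 'p1_ct.nii.gz', 'PT': 'p1_pt.nii.gz', 'GT': 'p1_gt.nii.gz'}]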

def remove_all_extensions(filename):
    """Strip every extension from a filename (handles double extensions like '.nii.gz')."""
    while True:
        name, ext = os.path.splitext(filename)
        if ext == '':
            return name
        filename = name
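
# Example (hypothetical filename): remove_all_extensions('fdg_0001.nii.gz') -> 'fdg_0001'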
#%%
def create_data_split_files():
    """Creates filepaths data for the training/validation and test images and saves
    them as `train_filepaths.csv` and `test_filepaths.csv` under WORKING_FOLDER/data_split/.
    Every training image is assigned a FoldID specifying which of the 5 folds
    the image belongs to. If `train_filepaths.csv` and `test_filepaths.csv` already exist,
    this function does nothing.
    """
    train_filepaths = os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv')
    test_filepaths = os.path.join(WORKING_FOLDER, 'data_split', 'test_filepaths.csv')
    if os.path.exists(train_filepaths) and os.path.exists(test_filepaths):
        return
    else:
        data_split_folder = os.path.join(WORKING_FOLDER, 'data_split')
        os.makedirs(data_split_folder, exist_ok=True)

        imagesTr = os.path.join(DATA_FOLDER, 'imagesTr')
        labelsTr = os.path.join(DATA_FOLDER, 'labelsTr')

        # CT volumes end in `0000.nii.gz`, PET volumes in `0001.nii.gz`.
        ctpaths = sorted(glob(os.path.join(imagesTr, '*0000.nii.gz')))
        ptpaths = sorted(glob(os.path.join(imagesTr, '*0001.nii.gz')))
        gtpaths = sorted(glob(os.path.join(labelsTr, '*.nii.gz')))
        imageids = [remove_all_extensions(os.path.basename(path)) for path in gtpaths]

        # Split the image IDs into 5 (nearly) equal folds; the first
        # `len(imageids) % n_folds` folds get one extra element each.
        n_folds = 5
        part_size = len(imageids) // n_folds
        remaining_elements = len(imageids) % n_folds
        start = 0
        train_folds = []
        for i in range(n_folds):
            end = start + part_size + (1 if i < remaining_elements else 0)
            train_folds.append(imageids[start:end])
            start = end

        fold_sizes = [len(fold) for fold in train_folds]
        foldids = [fold_sizes[i]*[i] for i in range(len(fold_sizes))]
        foldids = [item for sublist in foldids for item in sublist]

        trainfolds_data = np.column_stack((imageids, foldids, ctpaths, ptpaths, gtpaths))
        train_df = pd.DataFrame(trainfolds_data, columns=['ImageID', 'FoldID', 'CTPATH', 'PTPATH', 'GTPATH'])

        train_df.to_csv(train_filepaths, index=False)

        imagesTs = os.path.join(DATA_FOLDER, 'imagesTs')
        labelsTs = os.path.join(DATA_FOLDER, 'labelsTs')
        ctpaths_test = sorted(glob(os.path.join(imagesTs, '*0000.nii.gz')))
        ptpaths_test = sorted(glob(os.path.join(imagesTs, '*0001.nii.gz')))
        gtpaths_test = sorted(glob(os.path.join(labelsTs, '*.nii.gz')))
        imageids_test = [remove_all_extensions(os.path.basename(path)) for path in gtpaths_test]
        test_data = np.column_stack((imageids_test, ctpaths_test, ptpaths_test, gtpaths_test))
        test_df = pd.DataFrame(test_data, columns=['ImageID', 'CTPATH', 'PTPATH', 'GTPATH'])
        test_df.to_csv(test_filepaths, index=False)
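
# The resulting train_filepaths.csv then looks like (illustrative rows, not real paths):
#   ImageID,FoldID,CTPATH,PTPATH,GTPATH
#   fdg_0001,0,.../imagesTr/fdg_0001_0000.nii.gz,.../imagesTr/fdg_0001_0001.nii.gz,.../labelsTr/fdg_0001.nii.gz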

#%%
def get_train_valid_data_in_dict_format(fold):
    trainvalid_fpath = os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv')
    trainvalid_df = pd.read_csv(trainvalid_fpath)
    # The given fold is held out for validation; the other 4 folds are used for training.
    train_df = trainvalid_df[trainvalid_df['FoldID'] != fold]
    valid_df = trainvalid_df[trainvalid_df['FoldID'] == fold]

    ctpaths_train, ptpaths_train, gtpaths_train = list(train_df['CTPATH'].values), list(train_df['PTPATH'].values), list(train_df['GTPATH'].values)
    ctpaths_valid, ptpaths_valid, gtpaths_valid = list(valid_df['CTPATH'].values), list(valid_df['PTPATH'].values), list(valid_df['GTPATH'].values)

    train_data = create_dictionary_ctptgt(ctpaths_train, ptpaths_train, gtpaths_train)
    valid_data = create_dictionary_ctptgt(ctpaths_valid, ptpaths_valid, gtpaths_valid)

    return train_data, valid_data
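
# Typical usage (sketch): hold out fold 0 for validation.
# train_data, valid_data = get_train_valid_data_in_dict_format(fold=0)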

#%%
def get_test_data_in_dict_format():
    test_fpaths = os.path.join(WORKING_FOLDER, 'data_split', 'test_filepaths.csv')
    test_df = pd.read_csv(test_fpaths)
    ctpaths_test, ptpaths_test, gtpaths_test = list(test_df['CTPATH'].values), list(test_df['PTPATH'].values), list(test_df['GTPATH'].values)
    test_data = create_dictionary_ctptgt(ctpaths_test, ptpaths_test, gtpaths_test)
    return test_data

def get_spatial_size(input_patch_size=192):
    trsz = input_patch_size
    return (trsz, trsz, trsz)

def get_spacing():
    # Isotropic voxel spacing used when resampling.
    spc = 2
    return (spc, spc, spc)

def get_train_transforms(input_patch_size=192):
    spatialsize = get_spatial_size(input_patch_size)
    spacing = get_spacing()
    mod_keys = ['CT', 'PT', 'GT']
    train_transforms = Compose(
        [
            LoadImaged(keys=mod_keys, image_only=True),
            EnsureChannelFirstd(keys=mod_keys),
            CropForegroundd(keys=mod_keys, source_key='CT'),
            ScaleIntensityRanged(keys=['CT'], a_min=-154, a_max=325, b_min=0, b_max=1, clip=True),
            Orientationd(keys=mod_keys, axcodes="RAS"),
            Spacingd(keys=mod_keys, pixdim=spacing, mode=('bilinear', 'bilinear', 'nearest')),
            RandCropByPosNegLabeld(
                keys=mod_keys,
                label_key='GT',
                spatial_size=spatialsize,
                pos=2,
                neg=1,
                num_samples=1,
                image_key='PT',
                image_threshold=0,
                allow_smaller=True,
            ),
            ResizeWithPadOrCropd(
                keys=mod_keys,
                spatial_size=spatialsize,
                mode='constant'
            ),
            RandAffined(
                keys=mod_keys,
                mode=('bilinear', 'bilinear', 'nearest'),
                prob=0.5,
                spatial_size=spatialsize,
                translate_range=(10, 10, 10),
                rotate_range=(0, 0, np.pi/15),
                scale_range=(0.1, 0.1, 0.1)),
            # Stack CT and PET into a single 2-channel input and drop the originals.
            ConcatItemsd(keys=['CT', 'PT'], name='CTPT', dim=0),
            DeleteItemsd(keys=['CT', 'PT'])
        ])

    return train_transforms
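
# A minimal sketch of how these transforms might be consumed (assumes the usual
# MONAI Dataset/DataLoader pattern; not part of this module):
# from monai.data import Dataset, DataLoader
# train_ds = Dataset(data=train_data, transform=get_train_transforms(192))
# train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)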

#%%
def get_valid_transforms():
    spacing = get_spacing()
    mod_keys = ['CT', 'PT', 'GT']
    valid_transforms = Compose(
        [
            LoadImaged(keys=mod_keys),
            EnsureChannelFirstd(keys=mod_keys),
            CropForegroundd(keys=mod_keys, source_key='CT'),
            ScaleIntensityRanged(keys=['CT'], a_min=-154, a_max=325, b_min=0, b_max=1, clip=True),
            Orientationd(keys=mod_keys, axcodes="RAS"),
            Spacingd(keys=mod_keys, pixdim=spacing, mode=('bilinear', 'bilinear', 'nearest')),
            ConcatItemsd(keys=['CT', 'PT'], name='CTPT', dim=0),
            DeleteItemsd(keys=['CT', 'PT'])
        ])

    return valid_transforms


def get_post_transforms(test_transforms, save_preds_dir):
    post_transforms = Compose([
        Invertd(
            keys="Pred",
            transform=test_transforms,
            orig_keys="GT",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        AsDiscreted(keys="Pred", argmax=True),
        SaveImaged(
            keys="Pred",
            meta_keys="pred_meta_dict",
            output_dir=save_preds_dir,
            output_postfix="",
            separate_folder=False,
            resample=False,
        ),
    ])
    return post_transforms
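
# These post-transforms are applied per sample after inference, e.g. (sketch,
# assuming the standard MONAI decollate pattern):
# from monai.data import decollate_batch
# batch["Pred"] = model(batch["CTPT"])
# preds = [post_transforms(i) for i in decollate_batch(batch)]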

def get_kernels_strides(patch_size, spacings):
    """
    This function is only used for decathlon datasets with the provided patch sizes.
    When reusing this method for other tasks, ensure that the patch size in each spatial
    dimension is divisible by the product of all strides in the corresponding dimension.
    In addition, the minimal spatial size should have at least one dimension that is twice
    the product of all strides. For patch sizes with no suitable strides, an error is raised.
    """
    sizes, spacings = patch_size, spacings
    input_size = sizes
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        for idx, (i, j) in enumerate(zip(sizes, stride)):
            if i % j != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
                )
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides
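
# Worked example with this repo's defaults, patch (192, 192, 192) and spacing (2, 2, 2):
# the loop halves the size five times (192 -> 96 -> 48 -> 24 -> 12 -> 6, stopping once a
# dimension drops below 8), giving kernels = [[3, 3, 3]] * 6 and
# strides = [[1, 1, 1]] + [[2, 2, 2]] * 5.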
#%%
def get_model(network_name='unet', input_patch_size=192):
    if network_name == 'unet':
        model = UNet(
            spatial_dims=3,
            in_channels=2,
            out_channels=2,
            channels=(16, 32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH
        )
    elif network_name == 'swinunetr':
        spatialsize = get_spatial_size(input_patch_size)
        model = SwinUNETR(
            img_size=spatialsize,
            in_channels=2,
            out_channels=2,
            feature_size=12,
            use_checkpoint=False,
        )
    elif network_name == 'segresnet':
        model = SegResNet(
            spatial_dims=3,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=2,
            out_channels=2,
        )
    elif network_name == 'dynunet':
        spatialsize = get_spatial_size(input_patch_size)
        spacing = get_spacing()
        krnls, strds = get_kernels_strides(spatialsize, spacing)
        model = DynUNet(
            spatial_dims=3,
            in_channels=2,
            out_channels=2,
            kernel_size=krnls,
            strides=strds,
            upsample_kernel_size=strds[1:],
        )
    else:
        # Fail loudly instead of falling through and hitting an UnboundLocalError on `model`.
        raise ValueError(f"Unknown network_name: {network_name}")
    return model
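
# Example: a 2-channel (CT + PET) 3D UNet with two output classes.
# model = get_model(network_name='unet', input_patch_size=192)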

#%%
def get_loss_function():
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    return loss_function

def get_optimizer(model, learning_rate=2e-4, weight_decay=1e-5):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optimizer

def get_metric():
    metric = DiceMetric(include_background=False, reduction="mean")
    return metric

def get_scheduler(optimizer, max_epochs=500):
    scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=0)
    return scheduler

def get_validation_sliding_window_size(input_patch_size=192):
    # Maps a training patch size N to the sliding-window size W used for validation.
    dict_W_for_N = {
        96: 128,
        128: 160,
        160: 192,
        192: 192,
        224: 224,
        256: 256,
    }
    vlsz = dict_W_for_N[input_patch_size]
    return (vlsz, vlsz, vlsz)
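
# Example: get_validation_sliding_window_size(160) -> (192, 192, 192). Patch sizes
# outside the table above raise a KeyError.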