segmentation/initialize_train.py
'''
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
'''
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    DeleteItemsd,
    Spacingd,
    RandAffined,
    ConcatItemsd,
    ScaleIntensityRanged,
    ResizeWithPadOrCropd,
    Invertd,
    AsDiscreted,
    SaveImaged,
)
from monai.networks.nets import UNet, SegResNet, DynUNet, SwinUNETR, UNETR, AttentionUnet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
import torch
import matplotlib.pyplot as plt
from glob import glob
import pandas as pd
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import sys

# Make the repository root importable so `config` can be found
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
sys.path.append(config_dir)
from config import DATA_FOLDER, WORKING_FOLDER
#%%
def convert_to_4digits(str_num):
    """Zero-pad a numeric string to four digits, e.g. '7' -> '0007'
    (equivalent to str_num.zfill(4))."""
    if len(str_num) == 1:
        new_num = '000' + str_num
    elif len(str_num) == 2:
        new_num = '00' + str_num
    elif len(str_num) == 3:
        new_num = '0' + str_num
    else:
        new_num = str_num
    return new_num

def create_dictionary_ctptgt(ctpaths, ptpaths, gtpaths):
    """Pair CT, PET, and ground-truth label paths into a list of MONAI-style data dicts."""
    data = []
    for i in range(len(gtpaths)):
        ctpath = ctpaths[i]
        ptpath = ptpaths[i]
        gtpath = gtpaths[i]
        data.append({'CT': ctpath, 'PT': ptpath, 'GT': gtpath})
    return data

def remove_all_extensions(filename):
    """Strip every extension from a filename, e.g. 'case_0001.nii.gz' -> 'case_0001'."""
    while True:
        name, ext = os.path.splitext(filename)
        if ext == '':
            return name
        filename = name
#%%
def create_data_split_files():
    """Create filepath tables for the training/validation and test images and save
    them as `train_filepaths.csv` and `test_filepaths.csv` under WORKING_FOLDER/data_split/.
    Every training image is assigned a FoldID specifying which of the 5 cross-validation
    folds it belongs to. If both CSV files already exist, this function does nothing.
    """
    train_filepaths = os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv')
    test_filepaths = os.path.join(WORKING_FOLDER, 'data_split', 'test_filepaths.csv')
    if os.path.exists(train_filepaths) and os.path.exists(test_filepaths):
        return
    else:
        data_split_folder = os.path.join(WORKING_FOLDER, 'data_split')
        os.makedirs(data_split_folder, exist_ok=True)

        imagesTr = os.path.join(DATA_FOLDER, 'imagesTr')
        labelsTr = os.path.join(DATA_FOLDER, 'labelsTr')

        ctpaths = sorted(glob(os.path.join(imagesTr, '*0000.nii.gz')))
        ptpaths = sorted(glob(os.path.join(imagesTr, '*0001.nii.gz')))
        gtpaths = sorted(glob(os.path.join(labelsTr, '*.nii.gz')))
        imageids = [remove_all_extensions(os.path.basename(path)) for path in gtpaths]

        # Split the image IDs into 5 (nearly) equal folds; the first
        # `len(imageids) % n_folds` folds get one extra element each.
        n_folds = 5
        part_size = len(imageids) // n_folds
        remaining_elements = len(imageids) % n_folds
        start = 0
        train_folds = []
        for i in range(n_folds):
            end = start + part_size + (1 if i < remaining_elements else 0)
            train_folds.append(imageids[start:end])
            start = end

        fold_sizes = [len(fold) for fold in train_folds]
        foldids = [fold_sizes[i]*[i] for i in range(len(fold_sizes))]
        foldids = [item for sublist in foldids for item in sublist]

        trainfolds_data = np.column_stack((imageids, foldids, ctpaths, ptpaths, gtpaths))
        train_df = pd.DataFrame(trainfolds_data, columns=['ImageID', 'FoldID', 'CTPATH', 'PTPATH', 'GTPATH'])
        train_df.to_csv(train_filepaths, index=False)

        imagesTs = os.path.join(DATA_FOLDER, 'imagesTs')
        labelsTs = os.path.join(DATA_FOLDER, 'labelsTs')
        ctpaths_test = sorted(glob(os.path.join(imagesTs, '*0000.nii.gz')))
        ptpaths_test = sorted(glob(os.path.join(imagesTs, '*0001.nii.gz')))
        gtpaths_test = sorted(glob(os.path.join(labelsTs, '*.nii.gz')))
        imageids_test = [remove_all_extensions(os.path.basename(path)) for path in gtpaths_test]
        test_data = np.column_stack((imageids_test, ctpaths_test, ptpaths_test, gtpaths_test))
        test_df = pd.DataFrame(test_data, columns=['ImageID', 'CTPATH', 'PTPATH', 'GTPATH'])
        test_df.to_csv(test_filepaths, index=False)
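
#%%
# Example (sketch): building the split files and inspecting the fold assignment.
# The helper name `_example_inspect_data_split` is hypothetical and is not called
# anywhere in the pipeline; it assumes DATA_FOLDER follows the imagesTr/labelsTr
# layout used above.
def _example_inspect_data_split():
    create_data_split_files()
    train_df = pd.read_csv(os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv'))
    # Each training image carries a FoldID in {0, ..., 4}
    print(train_df['FoldID'].value_counts().sort_index())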
#%%
def get_train_valid_data_in_dict_format(fold):
    trainvalid_fpath = os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv')
    trainvalid_df = pd.read_csv(trainvalid_fpath)
    train_df = trainvalid_df[trainvalid_df['FoldID'] != fold]
    valid_df = trainvalid_df[trainvalid_df['FoldID'] == fold]

    ctpaths_train, ptpaths_train, gtpaths_train = list(train_df['CTPATH'].values), list(train_df['PTPATH'].values), list(train_df['GTPATH'].values)
    ctpaths_valid, ptpaths_valid, gtpaths_valid = list(valid_df['CTPATH'].values), list(valid_df['PTPATH'].values), list(valid_df['GTPATH'].values)

    train_data = create_dictionary_ctptgt(ctpaths_train, ptpaths_train, gtpaths_train)
    valid_data = create_dictionary_ctptgt(ctpaths_valid, ptpaths_valid, gtpaths_valid)

    return train_data, valid_data
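
#%%
# Example (sketch): the dict format returned above. `_example_show_fold_dicts`
# is a hypothetical helper for illustration only.
def _example_show_fold_dicts(fold=0):
    train_data, valid_data = get_train_valid_data_in_dict_format(fold)
    # Each element is a MONAI-style dict: {'CT': <path>, 'PT': <path>, 'GT': <path>}
    print(len(train_data), len(valid_data))
    print(valid_data[0])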
#%%
def get_test_data_in_dict_format():
    test_fpaths = os.path.join(WORKING_FOLDER, 'data_split', 'test_filepaths.csv')
    test_df = pd.read_csv(test_fpaths)
    ctpaths_test, ptpaths_test, gtpaths_test = list(test_df['CTPATH'].values), list(test_df['PTPATH'].values), list(test_df['GTPATH'].values)
    test_data = create_dictionary_ctptgt(ctpaths_test, ptpaths_test, gtpaths_test)
    return test_data

def get_spatial_size(input_patch_size=192):
    """Return an isotropic (D, H, W) training patch size."""
    trsz = input_patch_size
    return (trsz, trsz, trsz)

def get_spacing():
    """Return the isotropic voxel spacing (mm) used for resampling."""
    spc = 2
    return (spc, spc, spc)

def get_train_transforms(input_patch_size=192):
    spatialsize = get_spatial_size(input_patch_size)
    spacing = get_spacing()
    mod_keys = ['CT', 'PT', 'GT']
    train_transforms = Compose(
        [
            LoadImaged(keys=mod_keys, image_only=True),
            EnsureChannelFirstd(keys=mod_keys),
            CropForegroundd(keys=mod_keys, source_key='CT'),
            ScaleIntensityRanged(keys=['CT'], a_min=-154, a_max=325, b_min=0, b_max=1, clip=True),
            Orientationd(keys=mod_keys, axcodes="RAS"),
            Spacingd(keys=mod_keys, pixdim=spacing, mode=('bilinear', 'bilinear', 'nearest')),
            RandCropByPosNegLabeld(
                keys=mod_keys,
                label_key='GT',
                spatial_size=spatialsize,
                pos=2,
                neg=1,
                num_samples=1,
                image_key='PT',
                image_threshold=0,
                allow_smaller=True,
            ),
            ResizeWithPadOrCropd(
                keys=mod_keys,
                spatial_size=spatialsize,
                mode='constant'
            ),
            RandAffined(
                keys=mod_keys,
                mode=('bilinear', 'bilinear', 'nearest'),
                prob=0.5,
                spatial_size=spatialsize,
                translate_range=(10, 10, 10),
                rotate_range=(0, 0, np.pi/15),
                scale_range=(0.1, 0.1, 0.1)),
            # Stack CT and PET into a single 2-channel input and drop the originals
            ConcatItemsd(keys=['CT', 'PT'], name='CTPT', dim=0),
            DeleteItemsd(keys=['CT', 'PT'])
        ])

    return train_transforms

#%%
def get_valid_transforms():
    spacing = get_spacing()
    mod_keys = ['CT', 'PT', 'GT']
    valid_transforms = Compose(
        [
            LoadImaged(keys=mod_keys),
            EnsureChannelFirstd(keys=mod_keys),
            CropForegroundd(keys=mod_keys, source_key='CT'),
            ScaleIntensityRanged(keys=['CT'], a_min=-154, a_max=325, b_min=0, b_max=1, clip=True),
            Orientationd(keys=mod_keys, axcodes="RAS"),
            Spacingd(keys=mod_keys, pixdim=spacing, mode=('bilinear', 'bilinear', 'nearest')),
            ConcatItemsd(keys=['CT', 'PT'], name='CTPT', dim=0),
            DeleteItemsd(keys=['CT', 'PT'])
        ])

    return valid_transforms

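#%%
# Example (sketch): wiring the transforms into MONAI loaders. Batch size and
# num_workers are illustrative; the actual training script may use different
# values. `_example_build_loaders` is a hypothetical helper.
def _example_build_loaders(fold=0, input_patch_size=192):
    from monai.data import Dataset, DataLoader
    train_data, valid_data = get_train_valid_data_in_dict_format(fold)
    train_ds = Dataset(data=train_data, transform=get_train_transforms(input_patch_size))
    valid_ds = Dataset(data=valid_data, transform=get_valid_transforms())
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=2)
    valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False, num_workers=2)
    # After the transforms, each sample holds a 2-channel 'CTPT' image and a 'GT' label
    return train_loader, valid_loader
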
def get_post_transforms(test_transforms, save_preds_dir):
    post_transforms = Compose([
        Invertd(
            keys="Pred",
            transform=test_transforms,
            orig_keys="GT",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        AsDiscreted(keys="Pred", argmax=True),
        SaveImaged(keys="Pred", meta_keys="pred_meta_dict", output_dir=save_preds_dir, output_postfix="", separate_folder=False, resample=False),
    ])
    return post_transforms

def get_kernels_strides(patch_size, spacings):
    """
    This function is only used for decathlon datasets with the provided patch sizes.
    When using this method for other tasks, please ensure that the patch size in each spatial
    dimension is divisible by the product of all strides in the corresponding dimension.
    In addition, the minimal spatial size should have at least one dimension that has twice the size of
    the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
    """
    sizes, spacings = patch_size, spacings
    input_size = sizes
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        for idx, (i, j) in enumerate(zip(sizes, stride)):
            if i % j != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
                )
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides
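
# Worked example: with the defaults in this file, patch size (192, 192, 192) and
# spacing (2, 2, 2), every spacing ratio is 1, so each iteration halves all three
# dimensions (192 -> 96 -> 48 -> 24 -> 12 -> 6) until a size drops below 8.
# The result is six resolution levels:
#   kernels = [[3, 3, 3]] * 6
#   strides = [[1, 1, 1]] + [[2, 2, 2]] * 5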
#%%
def get_model(network_name='unet', input_patch_size=192):
    if network_name == 'unet':
        model = UNet(
            spatial_dims=3,
            in_channels=2,
            out_channels=2,
            channels=(16, 32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH
        )
    elif network_name == 'swinunetr':
        spatialsize = get_spatial_size(input_patch_size)
        model = SwinUNETR(
            img_size=spatialsize,
            in_channels=2,
            out_channels=2,
            feature_size=12,
            use_checkpoint=False,
        )
    elif network_name == 'segresnet':
        model = SegResNet(
            spatial_dims=3,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=2,
            out_channels=2,
        )
    elif network_name == 'dynunet':
        spatialsize = get_spatial_size(input_patch_size)
        spacing = get_spacing()
        krnls, strds = get_kernels_strides(spatialsize, spacing)
        model = DynUNet(
            spatial_dims=3,
            in_channels=2,
            out_channels=2,
            kernel_size=krnls,
            strides=strds,
            upsample_kernel_size=strds[1:],
        )
    else:
        # The original silently fell through and raised a NameError on `return model`
        raise ValueError(f"Unknown network_name: {network_name}")
    return model

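#%%
# Example (sketch): instantiating a network and checking tensor shapes. All models
# here map a 2-channel CT+PET input to 2-channel (background/foreground) logits at
# the same spatial size. `_example_model_forward` is a hypothetical helper; the
# 96^3 input is just a quick shape check, not the training patch size.
def _example_model_forward():
    model = get_model('unet', input_patch_size=192)
    model.eval()
    x = torch.randn(1, 2, 96, 96, 96)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 2, 96, 96, 96])
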
#%%
def get_loss_function():
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    return loss_function

def get_optimizer(model, learning_rate=2e-4, weight_decay=1e-5):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optimizer

def get_metric():
    metric = DiceMetric(include_background=False, reduction="mean")
    return metric

def get_scheduler(optimizer, max_epochs=500):
    scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=0)
    return scheduler

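#%%
# Example (sketch): how the pieces above fit together for a single training step.
# Device placement, logging, and checkpointing are omitted; `_example_training_step`
# is a hypothetical helper, not the actual training loop.
def _example_training_step(model, batch, optimizer, loss_function):
    inputs, labels = batch['CTPT'], batch['GT']
    optimizer.zero_grad()
    outputs = model(inputs)                # (B, 2, D, H, W) logits
    loss = loss_function(outputs, labels)  # DiceLoss applies softmax and one-hot internally
    loss.backward()
    optimizer.step()
    return loss.item()
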
def get_validation_sliding_window_size(input_patch_size=192):
    """Map a training patch size N to the sliding-window size W used at validation/inference."""
    dict_W_for_N = {
        96: 128,
        128: 160,
        160: 192,
        192: 192,
        224: 224,
        256: 256
    }
    vlsz = dict_W_for_N[input_patch_size]
    return (vlsz, vlsz, vlsz)
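
#%%
# Example (sketch): test-time inference with a sliding window sized by the table
# above, followed by the post transforms that invert the preprocessing and save
# predictions. `_example_inference` is a hypothetical helper; sw_batch_size=4 and
# the loop structure are illustrative.
def _example_inference(model, test_loader, save_preds_dir, input_patch_size=192):
    from monai.inferers import sliding_window_inference
    from monai.data import decollate_batch
    roi_size = get_validation_sliding_window_size(input_patch_size)
    post_transforms = get_post_transforms(get_valid_transforms(), save_preds_dir)
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch['Pred'] = sliding_window_inference(batch['CTPT'], roi_size, 4, model)
            # Invertd/AsDiscreted/SaveImaged operate on one sample dict at a time
            for sample in decollate_batch(batch):
                post_transforms(sample)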