--- a/segmentation/initialize_train.py
+++ b/segmentation/initialize_train.py
@@ -0,0 +1,331 @@
+'''
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
+'''
+from monai.transforms import (
+    EnsureChannelFirstd,
+    Compose,
+    CropForegroundd,
+    LoadImaged,
+    Orientationd,
+    RandCropByPosNegLabeld,
+    DeleteItemsd,
+    Spacingd,
+    RandAffined,
+    ConcatItemsd,
+    ScaleIntensityRanged,
+    ResizeWithPadOrCropd,
+    Invertd,
+    AsDiscreted,
+    SaveImaged,
+)
+from monai.networks.nets import UNet, SegResNet, DynUNet, SwinUNETR, UNETR, AttentionUnet
+from monai.networks.layers import Norm
+from monai.metrics import DiceMetric
+from monai.losses import DiceLoss
+import torch
+import matplotlib.pyplot as plt
+from glob import glob
+import pandas as pd
+import numpy as np
+from torch.optim.lr_scheduler import CosineAnnealingLR
+import os
+import sys
+config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
+sys.path.append(config_dir)
+from config import DATA_FOLDER, WORKING_FOLDER
+#%%
+def convert_to_4digits(str_num):
+    """Left-pad a numeric string to 4 digits (equivalent to str_num.zfill(4))."""
+    if len(str_num) == 1:
+        new_num = '000' + str_num
+    elif len(str_num) == 2:
+        new_num = '00' + str_num
+    elif len(str_num) == 3:
+        new_num = '0' + str_num
+    else:
+        new_num = str_num
+    return new_num
+
+def create_dictionary_ctptgt(ctpaths, ptpaths, gtpaths):
+    """Zip CT, PT, and GT filepaths into the list-of-dicts format expected by MONAI datasets."""
+    data = []
+    for ctpath, ptpath, gtpath in zip(ctpaths, ptpaths, gtpaths):
+        data.append({'CT': ctpath, 'PT': ptpath, 'GT': gtpath})
+    return data
+
+def remove_all_extensions(filename):
+    """Strip every extension from a filename, e.g. 'case.nii.gz' -> 'case'."""
+    while True:
+        name, ext = os.path.splitext(filename)
+        if ext == '':
+            return name
+        filename = name
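+
+# Illustrative examples of the helpers above (hypothetical filenames, not from the dataset):
+#   remove_all_extensions('PETCT_0001.nii.gz') -> 'PETCT_0001'
+#   convert_to_4digits('7') -> '0007'  (same result as '7'.zfill(4))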
+#%%
+def create_data_split_files():
+    """Build the train/validation and test filepath tables and save them as
+    `train_filepaths.csv` and `test_filepaths.csv` under WORKING_FOLDER/data_split/.
+    Each training image is assigned a FoldID specifying which of the 5
+    cross-validation folds it belongs to. If both CSV files already exist,
+    this function does nothing.
+    """
+    train_filepaths = os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv')
+    test_filepaths = os.path.join(WORKING_FOLDER, 'data_split', 'test_filepaths.csv')
+    if os.path.exists(train_filepaths) and os.path.exists(test_filepaths):
+        return
+    data_split_folder = os.path.join(WORKING_FOLDER, 'data_split')
+    os.makedirs(data_split_folder, exist_ok=True)
+
+    imagesTr = os.path.join(DATA_FOLDER, 'imagesTr')
+    labelsTr = os.path.join(DATA_FOLDER, 'labelsTr')
+
+    ctpaths = sorted(glob(os.path.join(imagesTr, '*0000.nii.gz')))
+    ptpaths = sorted(glob(os.path.join(imagesTr, '*0001.nii.gz')))
+    gtpaths = sorted(glob(os.path.join(labelsTr, '*.nii.gz')))
+    imageids = [remove_all_extensions(os.path.basename(path)) for path in gtpaths]
+
+    # Split the image IDs into 5 contiguous, nearly equal folds; the first
+    # `remaining_elements` folds absorb one extra image each.
+    n_folds = 5
+    part_size = len(imageids) // n_folds
+    remaining_elements = len(imageids) % n_folds
+    start = 0
+    train_folds = []
+    for i in range(n_folds):
+        end = start + part_size + (1 if i < remaining_elements else 0)
+        train_folds.append(imageids[start:end])
+        start = end
+
+    fold_sizes = [len(fold) for fold in train_folds]
+    foldids = [fold_sizes[i] * [i] for i in range(len(fold_sizes))]
+    foldids = [item for sublist in foldids for item in sublist]
+
+    trainfolds_data = np.column_stack((imageids, foldids, ctpaths, ptpaths, gtpaths))
+    train_df = pd.DataFrame(trainfolds_data, columns=['ImageID', 'FoldID', 'CTPATH', 'PTPATH', 'GTPATH'])
+    train_df.to_csv(train_filepaths, index=False)
+
+    imagesTs = os.path.join(DATA_FOLDER, 'imagesTs')
+    labelsTs = os.path.join(DATA_FOLDER, 'labelsTs')
+    ctpaths_test = sorted(glob(os.path.join(imagesTs, '*0000.nii.gz')))
+    ptpaths_test = sorted(glob(os.path.join(imagesTs, '*0001.nii.gz')))
+    gtpaths_test = sorted(glob(os.path.join(labelsTs, '*.nii.gz')))
+    imageids_test = [remove_all_extensions(os.path.basename(path)) for path in gtpaths_test]
+    test_data = np.column_stack((imageids_test, ctpaths_test, ptpaths_test, gtpaths_test))
+    test_df = pd.DataFrame(test_data, columns=['ImageID', 'CTPATH', 'PTPATH', 'GTPATH'])
+    test_df.to_csv(test_filepaths, index=False)
+
+#%%
+def get_train_valid_data_in_dict_format(fold):
+    """Return (train_data, valid_data) dict lists, holding out `fold` for validation."""
+    trainvalid_fpath = os.path.join(WORKING_FOLDER, 'data_split', 'train_filepaths.csv')
+    trainvalid_df = pd.read_csv(trainvalid_fpath)
+    train_df = trainvalid_df[trainvalid_df['FoldID'] != fold]
+    valid_df = trainvalid_df[trainvalid_df['FoldID'] == fold]
+
+    ctpaths_train, ptpaths_train, gtpaths_train = list(train_df['CTPATH'].values), list(train_df['PTPATH'].values), list(train_df['GTPATH'].values)
+    ctpaths_valid, ptpaths_valid, gtpaths_valid = list(valid_df['CTPATH'].values), list(valid_df['PTPATH'].values), list(valid_df['GTPATH'].values)
+
+    train_data = create_dictionary_ctptgt(ctpaths_train, ptpaths_train, gtpaths_train)
+    valid_data = create_dictionary_ctptgt(ctpaths_valid, ptpaths_valid, gtpaths_valid)
+    return train_data, valid_data
+
+#%%
+def get_test_data_in_dict_format():
+    test_fpaths = os.path.join(WORKING_FOLDER, 'data_split', 'test_filepaths.csv')
+    test_df = pd.read_csv(test_fpaths)
+    ctpaths_test, ptpaths_test, gtpaths_test = list(test_df['CTPATH'].values), list(test_df['PTPATH'].values), list(test_df['GTPATH'].values)
+    test_data = create_dictionary_ctptgt(ctpaths_test, ptpaths_test, gtpaths_test)
+    return test_data
+
+def get_spatial_size(input_patch_size=192):
+    """Cubic training patch size (N, N, N)."""
+    trsz = input_patch_size
+    return (trsz, trsz, trsz)
+
+def get_spacing():
+    """Isotropic resampling spacing in mm."""
+    spc = 2
+    return (spc, spc, spc)
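+
+# Illustrative example (hypothetical paths): get_train_valid_data_in_dict_format(0)
+# returns (train_data, valid_data) where each entry looks like
+#   {'CT': '.../imagesTr/case_0000.nii.gz',
+#    'PT': '.../imagesTr/case_0001.nii.gz',
+#    'GT': '.../labelsTr/case.nii.gz'}
+# with fold 0 held out for validation and folds 1-4 used for training.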
+def get_train_transforms(input_patch_size=192):
+    spatialsize = get_spatial_size(input_patch_size)
+    spacing = get_spacing()
+    mod_keys = ['CT', 'PT', 'GT']
+    train_transforms = Compose(
+        [
+            LoadImaged(keys=mod_keys, image_only=True),
+            EnsureChannelFirstd(keys=mod_keys),
+            CropForegroundd(keys=mod_keys, source_key='CT'),
+            ScaleIntensityRanged(keys=['CT'], a_min=-154, a_max=325, b_min=0, b_max=1, clip=True),
+            Orientationd(keys=mod_keys, axcodes="RAS"),
+            Spacingd(keys=mod_keys, pixdim=spacing, mode=('bilinear', 'bilinear', 'nearest')),
+            RandCropByPosNegLabeld(
+                keys=mod_keys,
+                label_key='GT',
+                spatial_size=spatialsize,
+                pos=2,
+                neg=1,
+                num_samples=1,
+                image_key='PT',
+                image_threshold=0,
+                allow_smaller=True,
+            ),
+            ResizeWithPadOrCropd(
+                keys=mod_keys,
+                spatial_size=spatialsize,
+                mode='constant'
+            ),
+            RandAffined(
+                keys=mod_keys,
+                mode=('bilinear', 'bilinear', 'nearest'),
+                prob=0.5,
+                spatial_size=spatialsize,
+                translate_range=(10, 10, 10),
+                rotate_range=(0, 0, np.pi/15),
+                scale_range=(0.1, 0.1, 0.1)),
+            ConcatItemsd(keys=['CT', 'PT'], name='CTPT', dim=0),
+            DeleteItemsd(keys=['CT', 'PT'])
+        ])
+    return train_transforms
+
+#%%
+def get_valid_transforms():
+    spacing = get_spacing()
+    mod_keys = ['CT', 'PT', 'GT']
+    valid_transforms = Compose(
+        [
+            LoadImaged(keys=mod_keys),
+            EnsureChannelFirstd(keys=mod_keys),
+            CropForegroundd(keys=mod_keys, source_key='CT'),
+            ScaleIntensityRanged(keys=['CT'], a_min=-154, a_max=325, b_min=0, b_max=1, clip=True),
+            Orientationd(keys=mod_keys, axcodes="RAS"),
+            Spacingd(keys=mod_keys, pixdim=spacing, mode=('bilinear', 'bilinear', 'nearest')),
+            ConcatItemsd(keys=['CT', 'PT'], name='CTPT', dim=0),
+            DeleteItemsd(keys=['CT', 'PT'])
+        ])
+    return valid_transforms
+
+def get_post_transforms(test_transforms, save_preds_dir):
+    """Invert the test-time preprocessing on the prediction, take the argmax, and save it to disk."""
+    post_transforms = Compose([
+        Invertd(
+            keys="Pred",
+            transform=test_transforms,
+            orig_keys="GT",
+            meta_keys="pred_meta_dict",
+            orig_meta_keys="image_meta_dict",
+            meta_key_postfix="meta_dict",
+            nearest_interp=False,
+            to_tensor=True,
+        ),
+        AsDiscreted(keys="Pred", argmax=True),
+        SaveImaged(keys="Pred", meta_keys="pred_meta_dict", output_dir=save_preds_dir, output_postfix="", separate_folder=False, resample=False),
+    ])
+    return post_transforms
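+
+# Note: with the default input_patch_size=192, a training sample leaves these
+# transforms with 'CTPT' of shape (2, 192, 192, 192) (CT and PT stacked along
+# the channel axis) and 'GT' of shape (1, 192, 192, 192); the validation
+# transforms apply no fixed-size crop, so validation image sizes vary per case.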
+def get_kernels_strides(patch_size, spacings):
+    """
+    This function was written for decathlon datasets with the provided patch sizes.
+    When reusing it for other tasks, ensure that the patch size in each spatial
+    dimension is divisible by the product of all strides in that dimension.
+    In addition, the minimal spatial size should have at least one dimension that
+    is at least twice the product of all strides. If no suitable strides can be
+    found for a patch size, a ValueError is raised.
+    """
+    sizes, spacings = patch_size, spacings
+    input_size = sizes
+    strides, kernels = [], []
+    while True:
+        spacing_ratio = [sp / min(spacings) for sp in spacings]
+        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
+        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
+        if all(s == 1 for s in stride):
+            break
+        for idx, (i, j) in enumerate(zip(sizes, stride)):
+            if i % j != 0:
+                raise ValueError(
+                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
+                )
+        sizes = [i / j for i, j in zip(sizes, stride)]
+        spacings = [i * j for i, j in zip(spacings, stride)]
+        kernels.append(kernel)
+        strides.append(stride)
+
+    strides.insert(0, len(spacings) * [1])
+    kernels.append(len(spacings) * [3])
+    return kernels, strides
+#%%
+def get_model(network_name='unet', input_patch_size=192):
+    if network_name == 'unet':
+        model = UNet(
+            spatial_dims=3,
+            in_channels=2,
+            out_channels=2,
+            channels=(16, 32, 64, 128, 256, 512),
+            strides=(2, 2, 2, 2, 2),
+            num_res_units=2,
+            norm=Norm.BATCH
+        )
+    elif network_name == 'swinunetr':
+        spatialsize = get_spatial_size(input_patch_size)
+        model = SwinUNETR(
+            img_size=spatialsize,
+            in_channels=2,
+            out_channels=2,
+            feature_size=12,
+            use_checkpoint=False,
+        )
+    elif network_name == 'segresnet':
+        model = SegResNet(
+            spatial_dims=3,
+            blocks_down=[1, 2, 2, 4],
+            blocks_up=[1, 1, 1],
+            init_filters=16,
+            in_channels=2,
+            out_channels=2,
+        )
+    elif network_name == 'dynunet':
+        spatialsize = get_spatial_size(input_patch_size)
+        spacing = get_spacing()
+        krnls, strds = get_kernels_strides(spatialsize, spacing)
+        model = DynUNet(
+            spatial_dims=3,
+            in_channels=2,
+            out_channels=2,
+            kernel_size=krnls,
+            strides=strds,
+            upsample_kernel_size=strds[1:],
+        )
+    else:
+        raise ValueError(f"Unsupported network_name: {network_name}; "
+                         "expected one of 'unet', 'swinunetr', 'segresnet', 'dynunet'.")
+    return model
+
+
+#%%
+def get_loss_function():
+    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
+    return loss_function
+
+def get_optimizer(model, learning_rate=2e-4, weight_decay=1e-5):
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    return optimizer
+
+def get_metric():
+    metric = DiceMetric(include_background=False, reduction="mean")
+    return metric
+
+def get_scheduler(optimizer, max_epochs=500):
+    scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=0)
+    return scheduler
+
+def get_validation_sliding_window_size(input_patch_size=192):
+    """Map the training patch size N to the sliding-window size W used for validation inference."""
+    dict_W_for_N = {
+        96: 128,
+        128: 160,
+        160: 192,
+        192: 192,
+        224: 224,
+        256: 256
+    }
+    vlsz = dict_W_for_N[input_patch_size]
+    return (vlsz, vlsz, vlsz)
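+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative; assumes DATA_FOLDER and
+    # WORKING_FOLDER in config point at a populated dataset). Builds each
+    # component with default settings to verify the module wires together.
+    create_data_split_files()
+    model = get_model('unet')
+    optimizer = get_optimizer(model)
+    scheduler = get_scheduler(optimizer)
+    print(type(model).__name__, type(optimizer).__name__, type(scheduler).__name__)
+    print(get_loss_function(), get_metric(), get_validation_sliding_window_size())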