# segment.py
1
import sys
2
import argparse
3
import os
4
from glob import glob
5
import json
6
import SimpleITK as sitk
7
8
# project-local helpers: logging utilities, lung segmentation, and dtype enums
9
from utils.logger import logger, pprint
10
from utils.dataset import segment_lungs_and_remove_trachea
11
from enums.dtype import DataTypes
12
13
14
# Filename suffixes of the exhale / inhale volumes inside each subject folder.
EXHALE_SUFFIX = "eBHCT.nii.gz"
INHALE_SUFFIX = "iBHCT.nii.gz"
NIFTI_SUFFIX = ".nii.gz"
# Suffix of the lung mask written next to each input volume.
LUNG_MASK_SUFFIX = "_lung.nii.gz"
# Metadata file expected in the parent directory of the split folder.
DESCRIPTION_FILENAME = "description.json"

# Parameters forwarded to segment_lungs_and_remove_trachea.
DEFAULT_THRESHOLD = 700
DEFAULT_FILL_HOLES_BEFORE_TRACHEA_REMOVAL = False
STRUCTURE = (7, 7, 7)
# Per-subject (threshold, fill_holes_before_trachea_removal) overrides:
# copd2 is segmented with a lower threshold and holes filled before trachea removal.
SUBJECT_OVERRIDES = {"copd2": (430, True)}


def _parse_args(argv=None):
    """Parse the command-line options of the segmentation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti training data')
    return parser.parse_args(argv)


def _find_volumes(dataset_path, suffix):
    """Return sorted, '/'-separated paths of files ending in *suffix*.

    Volumes are matched exactly one directory below *dataset_path*
    (``<dataset_path>/<subject>/<file><suffix>``); files placed directly in
    *dataset_path* or nested deeper are not picked up.
    """
    pattern = os.path.join(dataset_path, "*", f"*{suffix}")
    return [path.replace('\\', '/') for path in sorted(glob(pattern))]


def _subject_name(volume):
    """Return the name of the directory that directly contains *volume*."""
    return os.path.basename(os.path.dirname(volume))


def _normalized(dataset_path):
    """Return *dataset_path* with '/' separators and no trailing separator."""
    return os.path.normpath(dataset_path.replace('\\', '/'))


def _split_name(dataset_path):
    """Return the last path component of *dataset_path* (e.g. ``'train'``).

    This is the top-level key used to look up subjects in the description file.
    """
    return os.path.basename(_normalized(dataset_path))


def _description_path(dataset_path):
    """Return the path of the description file in the parent of *dataset_path*."""
    return os.path.join(os.path.dirname(_normalized(dataset_path)), DESCRIPTION_FILENAME)


def _segmentation_params(subject_name):
    """Return ``(threshold, fill_holes_before_trachea_removal)`` for *subject_name*."""
    return SUBJECT_OVERRIDES.get(
        subject_name, (DEFAULT_THRESHOLD, DEFAULT_FILL_HOLES_BEFORE_TRACHEA_REMOVAL)
    )


def _lung_mask_path(volume):
    """Return the mask output path for *volume* (``x.nii.gz`` -> ``x_lung.nii.gz``).

    Raises:
        ValueError: if *volume* does not end in ``.nii.gz``; writing to an
            unchanged path would overwrite the input volume.
    """
    if not volume.endswith(NIFTI_SUFFIX):
        raise ValueError(f"Expected a '{NIFTI_SUFFIX}' file, got {volume}")
    # Only the filename suffix is replaced, never matches elsewhere in the path.
    return volume[:-len(NIFTI_SUFFIX)] + LUNG_MASK_SUFFIX


def _segment_volume(volume, subject_information):
    """Segment the lungs in one NIfTI volume and write the mask next to it."""
    subject_name = _subject_name(volume)

    logger.info(f"Segmenting {volume}")
    sitk_image = sitk.ReadImage(volume)
    np_image = sitk.GetArrayFromImage(sitk_image)

    # logs
    print(subject_information)
    print("sitk:\t\t", sitk_image.GetSize(), sitk_image.GetPixelIDTypeAsString(), sitk_image.GetOrigin(), sitk_image.GetSpacing())
    print("np:\t\t", np_image.shape, np_image.dtype)

    threshold, fill_holes_before_trachea_removal = _segmentation_params(subject_name)
    print("thresh:\t\t", threshold)
    print("fill_holes:\t", fill_holes_before_trachea_removal)

    _, _, _, lung_segmentation = segment_lungs_and_remove_trachea(
        np_image,
        threshold=threshold,
        structure=STRUCTURE,
        fill_holes_before_trachea_removal=fill_holes_before_trachea_removal,
    )

    # Carry origin/spacing/direction of the input over to the mask.
    lung_segmentation_sitk = sitk.GetImageFromArray(lung_segmentation)
    lung_segmentation_sitk.CopyInformation(sitk_image)

    # logs
    print("lung:\t\t", lung_segmentation.shape, lung_segmentation.dtype)
    print("lung_sitk:\t", lung_segmentation_sitk.GetSize(), lung_segmentation_sitk.GetPixelIDTypeAsString(), lung_segmentation_sitk.GetOrigin(), lung_segmentation_sitk.GetSpacing(), "\n")

    sitk.WriteImage(lung_segmentation_sitk, _lung_mask_path(volume))


def main(argv=None):
    """Segment the lungs of every exhale/inhale volume under ``--dataset_path``.

    Each subject folder's ``*eBHCT.nii.gz`` / ``*iBHCT.nii.gz`` volume gets a
    ``*_lung.nii.gz`` mask written next to it. Subject metadata is read from
    ``description.json`` in the parent of the dataset path, keyed first by the
    split folder name and then by subject name.

    Exits with status 1 (after logging an error) when the dataset path, the
    description file, the split key, or any subject key is missing. Everything
    is validated before the first volume is processed.
    """
    args = _parse_args(argv)
    dataset_path = args.dataset_path

    if not os.path.exists(dataset_path):
        logger.error(f"Path {dataset_path} does not exist")
        sys.exit(1)

    logger.info(f"Reading nifti data from '{dataset_path}'")
    exhale_volumes = _find_volumes(dataset_path, EXHALE_SUFFIX)
    inhale_volumes = _find_volumes(dataset_path, INHALE_SUFFIX)

    logger.info(f"Found {len(exhale_volumes)} exhale volumes: ({[_subject_name(v) for v in exhale_volumes]})")
    logger.info(f"Found {len(inhale_volumes)} inhale volumes: ({[_subject_name(v) for v in inhale_volumes]})\n")
    pprint(exhale_volumes, inhale_volumes)
    print('\n')

    description_path = _description_path(dataset_path)
    if not os.path.isfile(description_path):
        logger.error(f"Description file {description_path} does not exist")
        sys.exit(1)
    with open(description_path, 'r', encoding='utf-8') as json_file:
        dictionary = json.load(json_file)

    split_name = _split_name(dataset_path)
    if split_name not in dictionary:
        logger.error(f"Split '{split_name}' not found in {description_path}")
        sys.exit(1)
    split_information = dictionary[split_name]

    volumes = exhale_volumes + inhale_volumes
    # Fail fast instead of crashing halfway through after some masks were written.
    missing = [name for name in sorted({_subject_name(v) for v in volumes}) if name not in split_information]
    if missing:
        logger.error(f"Subjects {missing} not found under '{split_name}' in {description_path}")
        sys.exit(1)

    for volume in volumes:
        _segment_volume(volume, split_information[_subject_name(volume)])

    print("Segmentation complete!")


if __name__ == "__main__":
    main()