Diff of /src/train.py [000000] .. [42b7b1]

b/src/train.py
""" Training, transfer learning, and validation (testing against reference ground truth images) of the proposed deep
learning model.

The script allows training a deep learning model on a given folder of PET images with the corresponding ground
truth (GT). It is assumed that the directory structure of the dataset for training and validation is given as follows:
    main_dir:
        -- patient_id_1:
            -- PET
                -- given_name.nii [.gz]
            -- GT
                -- given_name.nii [.gz]

        -- patient_id_2:
            -- PET
                -- given_name.nii [.gz]
            -- GT
                -- given_name.nii [.gz]

Please refer to the requirements.yml or requirements.txt files for the packages required to run this script. Using
an anaconda virtual environment to run the script is recommended.

e.g. python train.py --input_dir path/to/input/data --task [train or valid]

By K.B. Girum
"""
import os
import sys

# setup directory
p = os.path.abspath('../')
if p not in sys.path:
    sys.path.append(p)

from LFBNet.utilities import train_valid_paths
from LFBNet.preprocessing import preprocessing
from run import trainer, parse_argument


def main():
    """ Train and/or validate the selected model configuration. Get the parsed arguments from the parse_argument
    function, preprocess the data, generate MIPs, and train and validate.

    :parameter:
        --input_dir [path_to_pet_images]
        --dataset_name [unique_dataset_name]
        --output_dir [path_to_save_predicted_values] [optional]
        --task [train or valid]  # training or testing mode of the model [optional]

    :returns:
        - trained model if the task is train, and the predicted segmentation results if the task is valid.
        - predicted results will be saved to the predicted folder when the task is to predict.
        - It also saves the dice, sensitivity, and specificity of the model on each dataset, together with their
          average and median values.
        - It computes the quantitative surrogate metabolic tumor volume (sTMTV) and surrogate dissemination feature
          (sDmax) from the segmented and ground truth images and saves them as an xls file. The xls file columns are
          [patient_id, sTMTV_gt, sDmax_gt, sTMTV_prd, sDmax_prd].
        - Note: prd: predicted estimate, gt: ground truth from the expert.
    """
    # get the parsed arguments, such as the input directory path, output directory path, and task (train or valid)
    args = parse_argument.get_parsed_arguments()

    # get input and output data directories
    input_dir = args.input_dir
    train_valid_paths.directory_exist(input_dir)  # CHECK: check if the input directory has files

    # data identifier, or name
    dataset_name = args.data_identifier

    # how to split the training and validation data:
    # OPTION 1: provide a csv file with two columns listing the patient ids: column 1 ['train'] and column 2 ['valid'].
    # [set args.from_csv = True]. See the illustrative layout sketched below.
    # OPTION 2: let the program divide the whole given dataset into training and validation data randomly.
    # [set args.from_csv = False].
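    # For illustration only, a train/valid split csv placed under ../csv/ could look like the following
    # (hypothetical patient ids; the exact file name and column names expected by
    # trainer.get_training_and_validation_ids_from_csv should be checked against that function):
    #     train,valid
    #     patient_id_1,patient_id_4
    #     patient_id_2,patient_id_5
    #     patient_id_3,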

    train_valid_id_from_csv = args.from_csv

    # output directory to save the predicted results
    if args.output_dir:  # if given
        output_dir = args.output_dir
    else:
        # if not given, the predictions are saved under the default folder '../data/predicted'
        output_dir = '../data/predicted'
        if not os.path.exists('../data'):
            os.mkdir('../data')

        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

    # preprocessed data directory:
    preprocessing_dir = '../data/preprocessed'  # directory where the preprocessed 3D volumes and MIPs are saved
    if not os.path.exists('../data'):
        os.mkdir('../data')

    if not os.path.exists(preprocessing_dir):
        os.mkdir(preprocessing_dir)

    # default output data spacing
    desired_spacing = [4.0, 4.0, 4.0]
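    # Note: [4.0, 4.0, 4.0] is the target isotropic voxel spacing (presumably in mm, matching the resampling done in
    # the preprocessing step below); adjust it if a different output resolution is needed.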

    # STEP 1: read the raw .nii files in SUV, resize and crop them in 3D, generate the MIPs, and save them
    # get the directory path to the generated and saved MIPs; if it already exists, go for training or testing
    dir_mip = []
    # path to the training and validation data
    path_train_valid = dict(train=None, test=None)

    # preprocessing stage:
    preprocessing_params = dict(
        data_path=input_dir, data_name=dataset_name, saving_dir=preprocessing_dir, save_3D=True,
        output_resolution=[128, 128, 256], desired_spacing=desired_spacing, generate_mip=True)

    dir_mip = preprocessing.read_pet_gt_resize_crop_save_as_3d_andor_mip(**preprocessing_params)
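    # dir_mip is expected to point at the folder holding the generated MIP data (the 'valid' branch below rebuilds it
    # as preprocessing_dir + '/<dataset_name>_default_MIP_dir'); this assumption should be checked against
    # read_pet_gt_resize_crop_save_as_3d_andor_mip itself.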

    # training or validation/testing is selected from the input argument task
    task = args.task  # 'train' for training, 'valid' for validation/testing

    # training the deep learning model
    if task == 'train':
        # get the training and validation ids from the csv file (manually set): assuming this csv file has two
        # columns, with the 'train' column holding the training ids
        if train_valid_id_from_csv:
            train_valid_ids_path_csv = r'../csv/'
            train_ids, valid_ids = trainer.get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
        else:
            # generate a csv file for the training and validation data by randomly dividing the data into the two sets
            path_train_valid = dict(train=dir_mip)
            train_ids, valid_ids = train_valid_paths.get_train_valid_ids_from_folder(path_train_valid=path_train_valid)

        # train on the given input arguments
        trainer_params = dict(folder_preprocessed_train=dir_mip, folder_preprocessed_valid=dir_mip,
                              ids_to_read_train=train_ids, ids_to_read_valid=valid_ids, task=task,
                              predicted_directory=output_dir)

        network_run = trainer.NetworkTrainer(**trainer_params)
        network_run.train()

    # validation
    elif task == 'valid':
        dir_mip = os.path.join(preprocessing_dir, str(dataset_name) + "_default_MIP_dir")

        # get the ids from the csv file: assume the training ids are under the column name "train" and the testing
        # ids under "test"
        if train_valid_id_from_csv:
            train_valid_ids_path_csv = r'../csv/'
            train_ids, valid_ids = trainer.get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)

        else:
            # generate a csv file for the training and validation data by randomly dividing the data into the two sets
            path_train_valid = dict(train=dir_mip)
            train_ids, valid_ids = train_valid_paths.get_train_valid_ids_from_folder(path_train_valid=path_train_valid)

        trainer_params = dict(folder_preprocessed_train=dir_mip, folder_preprocessed_valid=dir_mip,
                              ids_to_read_train=train_ids, ids_to_read_valid=valid_ids, task=task,
                              predicted_directory=output_dir, save_predicted=True)

        network_run = trainer.NetworkTrainer(**trainer_params)
        network_run.train()

    else:
        print("keyword %s not recognized!\n" % task)


if __name__ == '__main__':
    print("Running the integrated framework ... \n\n")
    main()