src/train.py

""" Training, transfer learning, and validation (testing with reference ground truth images) of the proposed deep
learning model.

The script allows training a deep learning model on given PET images in a folder with the corresponding ground
truth (GT). It is assumed that the directory structure of the dataset for training and validation is given as
follows:

main_dir:
    -- patient_id_1:
        -- PET
            -- given_name.nii [.gz]
        -- GT
            -- given_name.nii [.gz]

    -- patient_id_2:
        -- PET
            -- given_name.nii [.gz]
        -- GT
            -- given_name.nii [.gz]

Please refer to the requirements.yml or requirements.txt files for the packages required to run this script. Using
an anaconda virtual environment is recommended to run the script.

e.g. python train.py --input_dir path/to/input/data --task [train or valid]

By K.B. Girum
"""
import os
import sys

# setup directory
p = os.path.abspath('../')
if p not in sys.path:
    sys.path.append(p)
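# NOTE: appending the repository root to sys.path is what lets the local LFBNet and run
# imports below resolve when the script is launched from inside src/.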
|
|
from LFBNet.utilities import train_valid_paths
from LFBNet.preprocessing import preprocessing
from run import trainer, parse_argument


def main():
    """ Train and/or validate the selected model configuration. Get the parsed arguments from the parse_argument
    function, then preprocess the data, generate the MIPs, and train and/or validate.

    :parameter:
        --input_dir [path_to_pet_images]
        --dataset_name [unique_dataset_name]
        --output_dir [path_to_save_predicted_values] [optional]
        --task test  # testing mode of the model [optional]

    :returns:
        - The trained model if the task is train, and the predicted segmentation results if the task is valid.
        - Predicted results are saved to the predicted folder when the task is to predict.
        - It also saves the dice, sensitivity, and specificity of the model on each dataset, together with their
          average and median values.
        - It computes the quantitative surrogate metabolic tumor volume (sTMTV) and surrogate dissemination
          feature (sDmax) from the segmented and ground truth images and saves them as an xls file, with columns
          [patient_id, sTMTV_gt, sDmax_gt, sTMTV_prd, sDmax_prd].
        - Note: prd: predicted estimate; gt: ground truth, i.e. from the expert.
    """
    # get the parsed arguments: the input directory path, the output directory path, and the task (train or valid)
    args = parse_argument.get_parsed_arguments()

    # get input and output data directories
    input_dir = args.input_dir
    train_valid_paths.directory_exist(input_dir)  # CHECK: check if the input directory has files

    # data identifier, or name
    dataset_name = args.data_identifier

    # How to split the training and validation data:
    # OPTION 1: provide a csv file with two columns listing patient ids: column 1 ['train'] and column 2 ['valid']
    # [set args.from_csv = True].
    # OPTION 2: let the program divide the given whole dataset into training and validation data randomly
    # [set args.from_csv = False].
    train_valid_id_from_csv = args.from_csv
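    # Illustrative csv layout for OPTION 1 (column names per the comments above; the ids are placeholders):
    #   train,          valid
    #   patient_id_1,   patient_id_3
    #   patient_id_2,   patient_id_4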
|
|
    # output directory to save the predicted results
    if args.output_dir:  # if given
        output_dir = args.output_dir
    else:
        # if not given, save under the default folder '../data/predicted'
        output_dir = '../data/predicted'  # directory for the predicted results
        if not os.path.exists('../data'):
            os.mkdir('../data')

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # preprocessed data directory:
    preprocessing_dir = '../data/preprocessed'  # directory for the preprocessed data and the MIPs
    if not os.path.exists('../data'):
        os.mkdir('../data')

    if not os.path.exists(preprocessing_dir):
        os.mkdir(preprocessing_dir)

    # default output data spacing
    desired_spacing = [4.0, 4.0, 4.0]
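    # NOTE: the spacing unit is presumably mm (4 mm isotropic voxels), the usual convention for PET resampling.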
|
|
    # STEP 1: read the raw .nii files in SUV, resize and crop them in 3D, generate the MIPs, and save them.
    # Get the directory path of the generated and saved MIPs; if it already exists, go on to training or testing.
    dir_mip = []
    # path to the training and validation data
    path_train_valid = dict(train=None, test=None)

    # preprocessing stage:
    preprocessing_params = dict(
        data_path=input_dir, data_name=dataset_name, saving_dir=preprocessing_dir, save_3D=True,
        output_resolution=[128, 128, 256], desired_spacing=desired_spacing, generate_mip=True)
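    # (output_resolution is presumably the target volume size in voxels, i.e. 128 x 128 x 256.)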
|
|
    dir_mip = preprocessing.read_pet_gt_resize_crop_save_as_3d_andor_mip(**preprocessing_params)
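    # dir_mip is assumed to hold the path(s) to the saved MIP folder(s); it is passed to the trainer
    # below as the preprocessed train/valid directory.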
|
|
    # training or validation/testing, selected by the input argument task
    task = args.task  # 'train' for training; 'valid' for testing or validation

    # training the deep learning model
    if task == 'train':
        # get the train/valid ids from the csv file (manually set, e.g. exported from excel): assuming the csv
        # file has two columns, with a 'train' column for the training data
        if train_valid_id_from_csv:
            train_valid_ids_path_csv = r'../csv/'
            train_ids, valid_ids = trainer.get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
        else:
            # generate a csv file for the training and validation data by dividing the data into training and
            # validation sets randomly
            path_train_valid = dict(train=dir_mip)
            train_ids, valid_ids = train_valid_paths.get_train_valid_ids_from_folder(path_train_valid=path_train_valid)

        # train on the given input arguments
        trainer_params = dict(
            folder_preprocessed_train=dir_mip, folder_preprocessed_valid=dir_mip,
            ids_to_read_train=train_ids, ids_to_read_valid=valid_ids, task=task,
            predicted_directory=output_dir)

        network_run = trainer.NetworkTrainer(**trainer_params)
        network_run.train()
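        # NetworkTrainer presumably builds the model and runs the training loop according to `task`.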
|
|
    # validation
    elif task == 'valid':
        dir_mip = os.path.join(preprocessing_dir, str(dataset_name) + "_default_MIP_dir")
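        # (assumes the preprocessing step above saved the MIPs under '<dataset_name>_default_MIP_dir')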
|
|
        # get the validation ids from the csv file: assume training ids are under the column name "train" and
        # testing ids under "test"
        if train_valid_id_from_csv:
            train_valid_ids_path_csv = r'../csv/'
            train_ids, valid_ids = trainer.get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
        else:
            # generate a csv file for the training and validation data by dividing the data into training and
            # validation sets randomly
            path_train_valid = dict(train=dir_mip)
            train_ids, valid_ids = train_valid_paths.get_train_valid_ids_from_folder(path_train_valid=path_train_valid)

        trainer_params = dict(
            folder_preprocessed_train=dir_mip, folder_preprocessed_valid=dir_mip,
            ids_to_read_train=train_ids, ids_to_read_valid=valid_ids, task=task,
            predicted_directory=output_dir, save_predicted=True)

        network_run = trainer.NetworkTrainer(**trainer_params)
        network_run.train()
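        # NOTE: with save_predicted=True, the predicted segmentations are presumably written to output_dir,
        # along with the dice/sensitivity/specificity and sTMTV/sDmax summaries described in the docstring.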
|
|
    else:
        print("keyword %s not recognized!\n" % task)


if __name__ == '__main__':
    print("Running the integrated framework ... \n\n")
    main()