src/train.py

""" training, transfer learning, and validation (testing with reference ground truth images) of the proposed deep
learning model.
The script allows to train a deep learning model on a given PET images in folder with corresponding ground truth (GT).
It is assumed that the directory structure of the dataset for training and validation are given as follows:
main_dir:
-- patient_id_1:
-- PET
--give_name.nii [.gz]
-- GT
-- give_name.nii [.gz]
-- patient_id_2:
-- PET
--give_name.nii [.gz]
-- GT
-- give_name.nii [.gz]
Please refer to the requirements.yml or requirements.txt files for the required packages to run this script. Using
anaconda virtual environment is recommended to runt the script.
e.g. python train.py --input_dir path/to/input/data --task [train or valid]
By K.B. Girum
"""

import os
import sys

# add the parent directory to the python path so that the LFBNet package and the run module are importable
p = os.path.abspath('../')
if p not in sys.path:
    sys.path.append(p)

from LFBNet.utilities import train_valid_paths
from LFBNet.preprocessing import preprocessing
from run import trainer, parse_argument

def main():
    """ Train and/or validate the selected model configuration.

    Gets the parsed arguments from the parse_argument function, preprocesses the data, generates MIPs, and then
    trains and validates the model.

    :parameter:
        --input_dir [path_to_pet_images]
        --data_identifier [unique_dataset_name]
        --output_dir [path_to_save_predicted_values] [optional]
        --task [train or valid]  # task to run [optional]

    :returns:
        - The trained model if the task is 'train', and the predicted segmentation results if the task is 'valid'.
        - Predicted results are saved to the predicted folder when the task is to predict.
        - It also saves the dice, sensitivity, and specificity of the model on each dataset, together with the
          average and median values.
        - It computes the quantitative surrogate metabolic tumor volume (sTMTV) and surrogate dissemination feature
          (sDmax) from the segmented and ground-truth images and saves them as an xls file with the columns
          [patient_id, sTMTV_gt, sDmax_gt, sTMTV_prd, sDmax_prd].
        - Note prd: predicted estimate, and gt: ground truth or from an expert.
    """
    # get the parsed arguments, such as the input directory path, the output directory path, and the task
    # (train or valid)
    args = parse_argument.get_parsed_arguments()

    # get the input data directory
    input_dir = args.input_dir
    train_valid_paths.directory_exist(input_dir)  # CHECK: check that the input directory exists and has files

    # data identifier, or name
    dataset_name = args.data_identifier

    # How to split the data into training and validation:
    # OPTION 1: provide a csv file with two columns of patient ids: column 1 ['train'] and column 2 ['valid']
    #           [set args.from_csv = True].
    # OPTION 2: let the program divide the whole dataset into training and validation data randomly
    #           [set args.from_csv = False].
    train_valid_id_from_csv = args.from_csv
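
    # For OPTION 1, an illustrative csv layout (the file name is hypothetical; the column names 'train' and
    # 'valid' are assumed from the comments in this script, and the file is read from '../csv/' below):
    #
    #   ../csv/train_valid_ids.csv:
    #       train,valid
    #       patient_id_1,patient_id_3
    #       patient_id_2,patient_id_4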

    # output directory to save the predicted results
    if args.output_dir:  # if given
        output_dir = args.output_dir
    else:
        # if not given, predictions are saved under '../data/predicted'
        output_dir = '../data/predicted'
    os.makedirs(output_dir, exist_ok=True)  # also creates '../data' if it does not exist yet

    # preprocessed data directory: the resized/cropped 3D volumes and the generated MIPs are saved here
    preprocessing_dir = '../data/preprocessed'
    os.makedirs(preprocessing_dir, exist_ok=True)

    # default output voxel spacing
    desired_spacing = [4.0, 4.0, 4.0]
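    # NOTE: the spacing values are assumed to be in mm per voxel (isotropic 4 mm resampling is a common choice
    # for PET); adjust desired_spacing if your data were reconstructed at a different resolution.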

    # STEP 1: read the raw .nii files in SUV, resize and crop them in 3D, generate the MIPs, and save everything.
    # The preprocessing function returns the directory path to the generated and saved MIPs; if that directory
    # already exists, it goes straight on to training or testing.
    preprocessing_params = dict(
        data_path=input_dir, data_name=dataset_name, saving_dir=preprocessing_dir, save_3D=True,
        output_resolution=[128, 128, 256], desired_spacing=desired_spacing, generate_mip=True)
    dir_mip = preprocessing.read_pet_gt_resize_crop_save_as_3d_andor_mip(**preprocessing_params)
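    # dir_mip now points to the folder of the saved MIPs; both the training and the validation readers below
    # load their preprocessed data from this folder.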

    # run training or validation/testing depending on the parsed task
    task = args.task  # 'train' for training, 'valid' for testing or validation

    # training the deep learning model
    if task == 'train':
        # get the train and valid ids from the csv file (manually set in excel): the csv file is assumed to have
        # a 'train' column with the training ids
        if train_valid_id_from_csv:
            train_valid_ids_path_csv = r'../csv/'
            train_ids, valid_ids = trainer.get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
        else:
            # generate a csv file for the training and validation ids by dividing the data randomly
            path_train_valid = dict(train=dir_mip)
            train_ids, valid_ids = train_valid_paths.get_train_valid_ids_from_folder(path_train_valid=path_train_valid)

        # train on the given input arguments
        trainer_params = dict(
            folder_preprocessed_train=dir_mip, folder_preprocessed_valid=dir_mip,
            ids_to_read_train=train_ids, ids_to_read_valid=valid_ids, task=task,
            predicted_directory=output_dir)
        network_run = trainer.NetworkTrainer(**trainer_params)
        network_run.train()

    # validation
    elif task == 'valid':
        dir_mip = os.path.join(preprocessing_dir, str(dataset_name) + "_default_MIP_dir")
        # get the ids from the csv file: training ids are assumed under the column name 'train' and testing ids
        # under 'test'
        if train_valid_id_from_csv:
            train_valid_ids_path_csv = r'../csv/'
            train_ids, valid_ids = trainer.get_training_and_validation_ids_from_csv(train_valid_ids_path_csv)
        else:
            # generate a csv file for the training and validation ids by dividing the data randomly
            path_train_valid = dict(train=dir_mip)
            train_ids, valid_ids = train_valid_paths.get_train_valid_ids_from_folder(path_train_valid=path_train_valid)

        trainer_params = dict(
            folder_preprocessed_train=dir_mip, folder_preprocessed_valid=dir_mip,
            ids_to_read_train=train_ids, ids_to_read_valid=valid_ids, task=task,
            predicted_directory=output_dir, save_predicted=True)
        network_run = trainer.NetworkTrainer(**trainer_params)
        network_run.train()

    else:
        print("keyword %s not recognized!\n" % task)


if __name__ == '__main__':
    print("Running the integrated framework ... \n\n")
    main()