|
a |
|
b/run.py |
|
|
1 |
import argparse |
|
|
2 |
import os |
|
|
3 |
import torch |
|
|
4 |
import utils.evaluator as eu |
|
|
5 |
from quicknat import QuickNat |
|
|
6 |
from settings import Settings |
|
|
7 |
from solver import Solver |
|
|
8 |
from utils.data_utils import get_imdb_dataset |
|
|
9 |
from utils.log_utils import LogWriter |
|
|
10 |
import logging |
|
|
11 |
import shutil |
|
|
12 |
|
|
|
13 |
torch.set_default_tensor_type('torch.FloatTensor') |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def load_data(data_params): |
|
|
17 |
print("Loading dataset") |
|
|
18 |
train_data, test_data = get_imdb_dataset(data_params) |
|
|
19 |
print("Train size: %i" % len(train_data)) |
|
|
20 |
print("Test size: %i" % len(test_data)) |
|
|
21 |
return train_data, test_data |
|
|
22 |
|
|
|
23 |
|
|
|
24 |
def train(train_params, common_params, data_params, net_params): |
|
|
25 |
|
|
|
26 |
train_data, test_data = load_data(data_params) |
|
|
27 |
|
|
|
28 |
train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_params['train_batch_size'], shuffle=True, |
|
|
29 |
num_workers=4, pin_memory=True) |
|
|
30 |
val_loader = torch.utils.data.DataLoader(test_data, batch_size=train_params['val_batch_size'], shuffle=False, |
|
|
31 |
num_workers=4, pin_memory=True) |
|
|
32 |
|
|
|
33 |
net_params_ = net_params.copy() |
|
|
34 |
empty_model = QuickNat(net_params_) |
|
|
35 |
if train_params['use_pre_trained']: |
|
|
36 |
quicknat_model = torch.load(train_params['pre_trained_path']) |
|
|
37 |
else: |
|
|
38 |
quicknat_model = QuickNat(net_params) |
|
|
39 |
|
|
|
40 |
solver = Solver(quicknat_model, |
|
|
41 |
device=common_params['device'], |
|
|
42 |
num_class=net_params['num_class'], |
|
|
43 |
optim_args={"lr": train_params['learning_rate'], |
|
|
44 |
"betas": train_params['optim_betas'], |
|
|
45 |
"eps": train_params['optim_eps'], |
|
|
46 |
"weight_decay": train_params['optim_weight_decay']}, |
|
|
47 |
model_name=common_params['model_name'], |
|
|
48 |
exp_name=train_params['exp_name'], |
|
|
49 |
labels=data_params['labels'], |
|
|
50 |
log_nth=train_params['log_nth'], |
|
|
51 |
num_epochs=train_params['num_epochs'], |
|
|
52 |
lr_scheduler_step_size=train_params['lr_scheduler_step_size'], |
|
|
53 |
lr_scheduler_gamma=train_params['lr_scheduler_gamma'], |
|
|
54 |
use_last_checkpoint=train_params['use_last_checkpoint'], |
|
|
55 |
log_dir=common_params['log_dir'], |
|
|
56 |
exp_dir=common_params['exp_dir']) |
|
|
57 |
|
|
|
58 |
solver.train(train_loader, val_loader) |
|
|
59 |
final_model_path = os.path.join(common_params['save_model_dir'], train_params['final_model_file']) |
|
|
60 |
# quicknat_model.save(final_model_path) |
|
|
61 |
solver.model = empty_model |
|
|
62 |
solver.save_best_model(final_model_path) |
|
|
63 |
print("final model saved @ " + str(final_model_path)) |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
def evaluate(eval_params, net_params, data_params, common_params, train_params): |
|
|
67 |
eval_model_path = eval_params['eval_model_path'] |
|
|
68 |
num_classes = net_params['num_class'] |
|
|
69 |
labels = data_params['labels'] |
|
|
70 |
data_dir = eval_params['data_dir'] |
|
|
71 |
label_dir = eval_params['label_dir'] |
|
|
72 |
volumes_txt_file = eval_params['volumes_txt_file'] |
|
|
73 |
remap_config = eval_params['remap_config'] |
|
|
74 |
device = common_params['device'] |
|
|
75 |
log_dir = common_params['log_dir'] |
|
|
76 |
exp_dir = common_params['exp_dir'] |
|
|
77 |
exp_name = train_params['exp_name'] |
|
|
78 |
save_predictions_dir = eval_params['save_predictions_dir'] |
|
|
79 |
prediction_path = os.path.join(exp_dir, exp_name, save_predictions_dir) |
|
|
80 |
orientation = eval_params['orientation'] |
|
|
81 |
data_id = eval_params['data_id'] |
|
|
82 |
|
|
|
83 |
logWriter = LogWriter(num_classes, log_dir, exp_name, labels=labels) |
|
|
84 |
|
|
|
85 |
avg_dice_score, class_dist = eu.evaluate_dice_score(eval_model_path, |
|
|
86 |
num_classes, |
|
|
87 |
data_dir, |
|
|
88 |
label_dir, |
|
|
89 |
volumes_txt_file, |
|
|
90 |
remap_config, |
|
|
91 |
orientation, |
|
|
92 |
prediction_path, |
|
|
93 |
data_id, |
|
|
94 |
device, |
|
|
95 |
logWriter) |
|
|
96 |
logWriter.close() |
|
|
97 |
|
|
|
98 |
|
|
|
99 |
def evaluate_bulk(eval_bulk): |
|
|
100 |
data_dir = eval_bulk['data_dir'] |
|
|
101 |
prediction_path = eval_bulk['save_predictions_dir'] |
|
|
102 |
volumes_txt_file = eval_bulk['volumes_txt_file'] |
|
|
103 |
device = eval_bulk['device'] |
|
|
104 |
label_names = ["vol_ID", "Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle", |
|
|
105 |
"Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen", |
|
|
106 |
"Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala", |
|
|
107 |
"CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex", |
|
|
108 |
"Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM", |
|
|
109 |
"Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum", |
|
|
110 |
"Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"] |
|
|
111 |
batch_size = eval_bulk['batch_size'] |
|
|
112 |
need_unc = eval_bulk['estimate_uncertainty'] |
|
|
113 |
mc_samples = eval_bulk['mc_samples'] |
|
|
114 |
dir_struct = eval_bulk['directory_struct'] |
|
|
115 |
if 'exit_on_error' in eval_bulk.keys(): |
|
|
116 |
exit_on_error = eval_bulk['exit_on_error'] |
|
|
117 |
else: |
|
|
118 |
exit_on_error = False |
|
|
119 |
|
|
|
120 |
if eval_bulk['view_agg'] == 'True': |
|
|
121 |
coronal_model_path = eval_bulk['coronal_model_path'] |
|
|
122 |
axial_model_path = eval_bulk['axial_model_path'] |
|
|
123 |
eu.evaluate2view(coronal_model_path, |
|
|
124 |
axial_model_path, |
|
|
125 |
volumes_txt_file, |
|
|
126 |
data_dir, device, |
|
|
127 |
prediction_path, |
|
|
128 |
batch_size, |
|
|
129 |
label_names, |
|
|
130 |
dir_struct, |
|
|
131 |
need_unc, |
|
|
132 |
mc_samples, |
|
|
133 |
exit_on_error=exit_on_error) |
|
|
134 |
else: |
|
|
135 |
coronal_model_path = eval_bulk['coronal_model_path'] |
|
|
136 |
eu.evaluate(coronal_model_path, |
|
|
137 |
volumes_txt_file, |
|
|
138 |
data_dir, |
|
|
139 |
device, |
|
|
140 |
prediction_path, |
|
|
141 |
batch_size, |
|
|
142 |
"COR", |
|
|
143 |
label_names, |
|
|
144 |
dir_struct, |
|
|
145 |
need_unc, |
|
|
146 |
mc_samples, |
|
|
147 |
exit_on_error=exit_on_error) |
|
|
148 |
|
|
|
149 |
def compute_vol(eval_bulk): |
|
|
150 |
prediction_path = eval_bulk['save_predictions_dir'] |
|
|
151 |
label_names = ["vol_ID", "Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle", |
|
|
152 |
"Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen", |
|
|
153 |
"Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala", |
|
|
154 |
"CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex", |
|
|
155 |
"Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM", |
|
|
156 |
"Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum", |
|
|
157 |
"Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"] |
|
|
158 |
volumes_txt_file = eval_bulk['volumes_txt_file'] |
|
|
159 |
|
|
|
160 |
eu.compute_vol_bulk(prediction_path, "Linear", label_names, volumes_txt_file) |
|
|
161 |
|
|
|
162 |
|
|
|
163 |
|
|
|
164 |
def delete_contents(folder): |
|
|
165 |
for the_file in os.listdir(folder): |
|
|
166 |
file_path = os.path.join(folder, the_file) |
|
|
167 |
try: |
|
|
168 |
if os.path.isfile(file_path): |
|
|
169 |
os.unlink(file_path) |
|
|
170 |
elif os.path.isdir(file_path): |
|
|
171 |
shutil.rmtree(file_path) |
|
|
172 |
except Exception as e: |
|
|
173 |
print(e) |
|
|
174 |
|
|
|
175 |
|
|
|
176 |
if __name__ == '__main__': |
|
|
177 |
|
|
|
178 |
parser = argparse.ArgumentParser() |
|
|
179 |
parser.add_argument('--mode', '-m', required=True, help='run mode, valid values are train and eval') |
|
|
180 |
parser.add_argument('--setting_path', '-sp', required=False, help='optional path to settings_eval.ini') |
|
|
181 |
args = parser.parse_args() |
|
|
182 |
|
|
|
183 |
settings = Settings('settings.ini') |
|
|
184 |
common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], \ |
|
|
185 |
settings[ |
|
|
186 |
'NETWORK'], settings['TRAINING'], \ |
|
|
187 |
settings['EVAL'] |
|
|
188 |
if args.mode == 'train': |
|
|
189 |
train(train_params, common_params, data_params, net_params) |
|
|
190 |
elif args.mode == 'eval': |
|
|
191 |
evaluate(eval_params, net_params, data_params, common_params, train_params) |
|
|
192 |
elif args.mode == 'eval_bulk': |
|
|
193 |
logging.basicConfig(filename='error.log') |
|
|
194 |
if args.setting_path is not None: |
|
|
195 |
settings_eval = Settings(args.setting_path) |
|
|
196 |
else: |
|
|
197 |
settings_eval = Settings('settings_eval.ini') |
|
|
198 |
evaluate_bulk(settings_eval['EVAL_BULK']) |
|
|
199 |
elif args.mode == 'clear': |
|
|
200 |
shutil.rmtree(os.path.join(common_params['exp_dir'], train_params['exp_name'])) |
|
|
201 |
print("Cleared current experiment directory successfully!!") |
|
|
202 |
shutil.rmtree(os.path.join(common_params['log_dir'], train_params['exp_name'])) |
|
|
203 |
print("Cleared current log directory successfully!!") |
|
|
204 |
|
|
|
205 |
elif args.mode == 'clear-all': |
|
|
206 |
delete_contents(common_params['exp_dir']) |
|
|
207 |
print("Cleared experiments directory successfully!!") |
|
|
208 |
delete_contents(common_params['log_dir']) |
|
|
209 |
print("Cleared logs directory successfully!!") |
|
|
210 |
|
|
|
211 |
elif args.mode == 'compute_vol': |
|
|
212 |
if args.setting_path is not None: |
|
|
213 |
settings_eval = Settings(args.setting_path) |
|
|
214 |
else: |
|
|
215 |
settings_eval = Settings('settings_eval.ini') |
|
|
216 |
compute_vol(settings_eval['EVAL_BULK']) |
|
|
217 |
else: |
|
|
218 |
raise ValueError('Invalid value for mode. only support values are train, eval and clear') |