Diff of /run.py [000000] .. [6f9c00]

b/run.py
import argparse
import logging
import os
import shutil

import torch

import utils.evaluator as eu
from quicknat import QuickNat
from settings import Settings
from solver import Solver
from utils.data_utils import get_imdb_dataset
from utils.log_utils import LogWriter

torch.set_default_tensor_type('torch.FloatTensor')


def load_data(data_params):
    """Load the training and testing splits described by data_params and report their sizes."""
    print("Loading dataset")
    train_data, test_data = get_imdb_dataset(data_params)
    print("Train size: %i" % len(train_data))
    print("Test size: %i" % len(test_data))
    return train_data, test_data


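# Illustrative sketch of the settings consumed by train() below. The key names come
# from the lookups in this file; the section layout and every value shown are
# assumptions for documentation only -- the settings.ini shipped with the project is
# authoritative. (The [DATA] section is read by get_imdb_dataset() and is not
# sketched here.)
#
#   [COMMON]
#   device = 0
#   model_name = quicknat
#   log_dir = logs/
#   exp_dir = experiments/
#   save_model_dir = saved_models/
#
#   [TRAINING]
#   exp_name = quicknat_run1
#   train_batch_size = 8
#   val_batch_size = 8
#   learning_rate = 1e-4
#   optim_betas = (0.9, 0.999)
#   optim_eps = 1e-8
#   optim_weight_decay = 1e-4
#   num_epochs = 20
#   log_nth = 50
#   lr_scheduler_step_size = 5
#   lr_scheduler_gamma = 0.5
#   use_pre_trained = False
#   pre_trained_path = saved_models/pretrained.pth.tar
#   use_last_checkpoint = True
#   final_model_file = quicknat_final.pth.tar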
def train(train_params, common_params, data_params, net_params):
    """Train a QuickNat model and save the best checkpoint as the final model."""
    train_data, test_data = load_data(data_params)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_params['train_batch_size'], shuffle=True,
                                               num_workers=4, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(test_data, batch_size=train_params['val_batch_size'], shuffle=False,
                                             num_workers=4, pin_memory=True)

    net_params_ = net_params.copy()
    empty_model = QuickNat(net_params_)
    if train_params['use_pre_trained']:
        quicknat_model = torch.load(train_params['pre_trained_path'])
    else:
        quicknat_model = QuickNat(net_params)

    solver = Solver(quicknat_model,
                    device=common_params['device'],
                    num_class=net_params['num_class'],
                    optim_args={"lr": train_params['learning_rate'],
                                "betas": train_params['optim_betas'],
                                "eps": train_params['optim_eps'],
                                "weight_decay": train_params['optim_weight_decay']},
                    model_name=common_params['model_name'],
                    exp_name=train_params['exp_name'],
                    labels=data_params['labels'],
                    log_nth=train_params['log_nth'],
                    num_epochs=train_params['num_epochs'],
                    lr_scheduler_step_size=train_params['lr_scheduler_step_size'],
                    lr_scheduler_gamma=train_params['lr_scheduler_gamma'],
                    use_last_checkpoint=train_params['use_last_checkpoint'],
                    log_dir=common_params['log_dir'],
                    exp_dir=common_params['exp_dir'])

    solver.train(train_loader, val_loader)
    final_model_path = os.path.join(common_params['save_model_dir'], train_params['final_model_file'])
    # quicknat_model.save(final_model_path)
    # Export the best checkpoint using a fresh (untrained) model instance.
    solver.model = empty_model
    solver.save_best_model(final_model_path)
    print("Final model saved @ " + str(final_model_path))


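# Illustrative sketch of an [EVAL] section for single-model evaluation. The keys
# match the lookups in evaluate() below; all paths and values are placeholders,
# not the project's actual configuration:
#
#   [EVAL]
#   eval_model_path = saved_models/quicknat_final.pth.tar
#   data_dir = datasets/test/volumes
#   label_dir = datasets/test/labels
#   volumes_txt_file = datasets/test/volumes.txt
#   remap_config = Neo
#   orientation = COR
#   data_id = MALC
#   save_predictions_dir = predictions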
def evaluate(eval_params, net_params, data_params, common_params, train_params):
    """Evaluate a trained model by Dice score on the volumes listed in volumes_txt_file."""
    eval_model_path = eval_params['eval_model_path']
    num_classes = net_params['num_class']
    labels = data_params['labels']
    data_dir = eval_params['data_dir']
    label_dir = eval_params['label_dir']
    volumes_txt_file = eval_params['volumes_txt_file']
    remap_config = eval_params['remap_config']
    device = common_params['device']
    log_dir = common_params['log_dir']
    exp_dir = common_params['exp_dir']
    exp_name = train_params['exp_name']
    save_predictions_dir = eval_params['save_predictions_dir']
    prediction_path = os.path.join(exp_dir, exp_name, save_predictions_dir)
    orientation = eval_params['orientation']
    data_id = eval_params['data_id']

    logWriter = LogWriter(num_classes, log_dir, exp_name, labels=labels)

    avg_dice_score, class_dist = eu.evaluate_dice_score(eval_model_path,
                                                        num_classes,
                                                        data_dir,
                                                        label_dir,
                                                        volumes_txt_file,
                                                        remap_config,
                                                        orientation,
                                                        prediction_path,
                                                        data_id,
                                                        device,
                                                        logWriter)
    logWriter.close()


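# Illustrative sketch of an [EVAL_BULK] section as read by evaluate_bulk() and
# compute_vol() below. The section name and key names are taken from this file;
# every value shown is an invented placeholder:
#
#   [EVAL_BULK]
#   device = 0
#   coronal_model_path = saved_models/quicknat_coronal.pth.tar
#   axial_model_path = saved_models/quicknat_axial.pth.tar
#   view_agg = True
#   data_dir = datasets/bulk/volumes
#   volumes_txt_file = datasets/bulk/volumes.txt
#   save_predictions_dir = predictions_bulk
#   batch_size = 8
#   estimate_uncertainty = False
#   mc_samples = 10
#   directory_struct = FS
#   exit_on_error = True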
def evaluate_bulk(eval_bulk):
    """Segment a list of volumes in bulk, optionally aggregating coronal and axial views."""
    data_dir = eval_bulk['data_dir']
    prediction_path = eval_bulk['save_predictions_dir']
    volumes_txt_file = eval_bulk['volumes_txt_file']
    device = eval_bulk['device']
    label_names = ["vol_ID", "Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle",
                   "Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen",
                   "Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala",
                   "CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex",
                   "Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM",
                   "Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum",
                   "Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"]
    batch_size = eval_bulk['batch_size']
    need_unc = eval_bulk['estimate_uncertainty']
    mc_samples = eval_bulk['mc_samples']
    dir_struct = eval_bulk['directory_struct']
    if 'exit_on_error' in eval_bulk:
        exit_on_error = eval_bulk['exit_on_error']
    else:
        exit_on_error = False

    if eval_bulk['view_agg'] == 'True':
        coronal_model_path = eval_bulk['coronal_model_path']
        axial_model_path = eval_bulk['axial_model_path']
        eu.evaluate2view(coronal_model_path,
                         axial_model_path,
                         volumes_txt_file,
                         data_dir, device,
                         prediction_path,
                         batch_size,
                         label_names,
                         dir_struct,
                         need_unc,
                         mc_samples,
                         exit_on_error=exit_on_error)
    else:
        coronal_model_path = eval_bulk['coronal_model_path']
        eu.evaluate(coronal_model_path,
                    volumes_txt_file,
                    data_dir,
                    device,
                    prediction_path,
                    batch_size,
                    "COR",
                    label_names,
                    dir_struct,
                    need_unc,
                    mc_samples,
                    exit_on_error=exit_on_error)


def compute_vol(eval_bulk):
    """Compute per-structure volumes from previously saved segmentations."""
    prediction_path = eval_bulk['save_predictions_dir']
    label_names = ["vol_ID", "Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle",
                   "Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen",
                   "Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala",
                   "CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex",
                   "Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM",
                   "Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum",
                   "Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"]
    volumes_txt_file = eval_bulk['volumes_txt_file']

    eu.compute_vol_bulk(prediction_path, "Linear", label_names, volumes_txt_file)


def delete_contents(folder):
    """Delete every file and sub-directory inside the given folder, keeping the folder itself."""
    for the_file in os.listdir(folder):
        file_path = os.path.join(folder, the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(e)


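# Example invocations of this script, based on the modes handled below (the file
# name my_eval_settings.ini is a placeholder):
#
#   python run.py --mode train
#   python run.py --mode eval
#   python run.py --mode eval_bulk --setting_path my_eval_settings.ini
#   python run.py --mode compute_vol
#   python run.py --mode clear        # remove the current experiment and log directories
#   python run.py --mode clear-all    # empty the experiments and logs directories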
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', '-m', required=True,
                        help='run mode, valid values are train, eval, eval_bulk, compute_vol, clear and clear-all')
    parser.add_argument('--setting_path', '-sp', required=False, help='optional path to settings_eval.ini')
    args = parser.parse_args()

    settings = Settings('settings.ini')
    common_params = settings['COMMON']
    data_params = settings['DATA']
    net_params = settings['NETWORK']
    train_params = settings['TRAINING']
    eval_params = settings['EVAL']

    if args.mode == 'train':
        train(train_params, common_params, data_params, net_params)
    elif args.mode == 'eval':
        evaluate(eval_params, net_params, data_params, common_params, train_params)
    elif args.mode == 'eval_bulk':
        logging.basicConfig(filename='error.log')
        if args.setting_path is not None:
            settings_eval = Settings(args.setting_path)
        else:
            settings_eval = Settings('settings_eval.ini')
        evaluate_bulk(settings_eval['EVAL_BULK'])
    elif args.mode == 'clear':
        shutil.rmtree(os.path.join(common_params['exp_dir'], train_params['exp_name']))
        print("Cleared current experiment directory successfully!!")
        shutil.rmtree(os.path.join(common_params['log_dir'], train_params['exp_name']))
        print("Cleared current log directory successfully!!")
    elif args.mode == 'clear-all':
        delete_contents(common_params['exp_dir'])
        print("Cleared experiments directory successfully!!")
        delete_contents(common_params['log_dir'])
        print("Cleared logs directory successfully!!")
    elif args.mode == 'compute_vol':
        if args.setting_path is not None:
            settings_eval = Settings(args.setting_path)
        else:
            settings_eval = Settings('settings_eval.ini')
        compute_vol(settings_eval['EVAL_BULK'])
    else:
        raise ValueError('Invalid value for mode. Supported values are train, eval, eval_bulk, compute_vol, '
                         'clear and clear-all.')