b/utils/exp_utils.py

#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, Tuple, Any, Union
import os, sys
import subprocess
from multiprocessing import Process

import importlib.util
import pickle

import logging
from torch.utils.tensorboard import SummaryWriter

from collections import OrderedDict
import numpy as np
import torch
import pandas as pd

def split_off_process(target, *args, daemon: bool=False, **kwargs):
    """Start a process that won't block the parent script.
    The process is not joined here; its handle is returned. If daemon=False, the parent
    waits for this process to finish before it exits.
    :param target: the target function of the process.
    :param args: args to pass to target.
    :param daemon: if False, the parent waits for this process to finish before exiting.
    :param kwargs: kwargs to pass to target.
    """
    p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=daemon)
    p.start()
    return p

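# Usage sketch (the target and its args are illustrative placeholders, not part of this module):
#   p = split_off_process(np.save, "/tmp/arr.npy", some_array, daemon=False)
#   # the parent continues immediately; with daemon=False, Python waits for p before exiting.
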
def get_formatted_duration(seconds: float, format: str="hms") -> str:
    """Format a duration given in seconds.
    :param format: "hms" for hours, mins, secs, or "ms" for mins, secs.
    """
    mins, secs = divmod(seconds, 60)
    if format == "ms":
        t = "{:d}m:{:02d}s".format(int(mins), int(secs))
    elif format == "hms":
        h, mins = divmod(mins, 60)
        t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
    else:
        raise ValueError("Format {} not available, only 'hms' or 'ms'".format(format))
    return t

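# Worked examples (values illustrative):
#   get_formatted_duration(3661.5)      -> '1h:01m:01s'
#   get_formatted_duration(125, "ms")   -> '2m:05s'
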
class CombinedLogger(object):
    """Combines a python logger (console/file output) and a tensorboard SummaryWriter
    behind a single interface.
    """

    def __init__(self, name: str, log_dir: str, server_env: bool=True, fold: Union[int, str]="all"):
        self.pylogger = logging.getLogger(name)
        self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard"))
        self.log_dir = log_dir
        self.fold = str(fold)
        self.server_env = server_env

        self.pylogger.setLevel(logging.DEBUG)
        self.log_file = os.path.join(log_dir, "fold_" + self.fold, 'exec.log')
        os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
        self.pylogger.addHandler(logging.FileHandler(self.log_file))
        if not server_env:
            self.pylogger.addHandler(ColorHandler())
        else:
            self.pylogger.addHandler(logging.StreamHandler())
        self.pylogger.propagate = False

    def __getattr__(self, attr):
        """Delegate all undefined attribute requests to the wrapped objects,
        in order pylogger, tboard (first come, first served).
        E.g., combinedlogger.add_scalars(...) triggers self.tboard.add_scalars(...).
        """
        for obj in [self.pylogger, self.tboard]:
            if attr in dir(obj):
                return getattr(obj, attr)
        raise AttributeError("attr {} found neither in logger nor in tboard".format(attr))

    def set_logfile(self, fold: Union[int, str, None]=None, log_file: Union[str, None]=None):
        if fold is not None:
            self.fold = str(fold)
        if log_file is None:
            self.log_file = os.path.join(self.log_dir, "fold_" + self.fold, 'exec.log')
        else:
            self.log_file = log_file
        os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
        for hdlr in self.pylogger.handlers:
            hdlr.close()
        self.pylogger.handlers = []
        self.pylogger.addHandler(logging.FileHandler(self.log_file))
        if not self.server_env:
            self.pylogger.addHandler(ColorHandler())
        else:
            self.pylogger.addHandler(logging.StreamHandler())

    def metrics2tboard(self, metrics, global_step=None, suptitle=None):
        """
        :param metrics: {'train': dataframe, 'val': df}, df as produced in
            evaluator.py.evaluate_predictions.
        """
        if global_step is None:
            global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1
        if suptitle is not None:
            suptitle = str(suptitle)
        else:
            suptitle = "Fold_" + str(self.fold)

        for key in ['train', 'val']:
            loss_series = {}
            mon_met_series = {}
            for tag, val in metrics[key].items():
                val = val[-1]  # maybe remove list wrapping (recording in evaluator)?
                if 'loss' in tag.lower() and not np.isnan(val):
                    loss_series["{}".format(tag)] = val
                elif not np.isnan(val):
                    mon_met_series["{}".format(tag)] = val

            self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step)
            self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step)
        self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step)
        return

    def __del__(self):  # otherwise might produce multiple prints e.g. in ipython console
        for hdlr in self.pylogger.handlers:
            hdlr.close()
        self.pylogger.handlers = []
        del self.pylogger
        self.tboard.flush()
        # close somehow prevents main script from exiting
        # maybe revise this issue in a later pytorch version
        #self.tboard.close()

def get_logger(exp_dir: str, server_env: bool=False) -> CombinedLogger:
    """
    Creates a logger instance that writes info to file, to terminal, and to tensorboard.
    :param exp_dir: experiment directory, where the exec.log file is stored.
    :param server_env: True if operating in a server environment (e.g., gpu cluster).
    :return: custom CombinedLogger instance.
    """
    log_dir = os.path.join(exp_dir, "logs")
    logger = CombinedLogger('medicaldetectiontoolkit', log_dir, server_env=server_env)
    print("Logging to {}".format(logger.log_file))
    return logger

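# Usage sketch (paths illustrative). Undefined attributes are delegated via
# CombinedLogger.__getattr__, so one object serves both logging targets:
#   logger = get_logger("/experiments/my_exp")
#   logger.info("starting fold 0")                                            # -> console + exec.log
#   logger.add_scalars("Fold_0/Losses/train", {"loss": 0.7}, global_step=1)   # -> tboard
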
def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True):
    """
    I/O handling and creation of the experiment folder structure. Also creates a snapshot of the
    configs/model scripts and copies them to the exp_dir, so that the exp_dir holds all info needed
    to conduct the experiment independently of later changes to the source code. Training/inference
    for this experiment can thus be (re)started at any time: cf.model_path and cf.backbone_path are
    simply pointed at the snapshot scripts inside the exp_dir. This also gives a robust structure
    for cloud/cluster deployment.
    :param dataset_path: path to the source code for the specific data set. (e.g. medicaldetectiontoolkit/lidc_exp)
    :param exp_path: path to the experiment directory.
    :param server_env: boolean flag. passed to the configs script for cloud deployment.
    :param use_stored_settings: boolean flag. When starting training: if True, training is started from the
        snapshot in the existing experiment directory, else the experiment directory is created on the fly
        using the configs/model scripts from the source code.
    :param is_training: boolean flag. distinguishes train vs. inference mode.
    :return: configs object cf.
    """

    if is_training:
        if use_stored_settings:
            cf_file = import_module('cf_file', os.path.join(exp_path, 'configs.py'))
            cf = cf_file.configs(server_env)
            # in this mode, previously saved model and backbone need to be found in the exp dir.
            if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \
                    not os.path.isfile(os.path.join(exp_path, 'backbone.py')):
                raise FileNotFoundError(
                    "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.")
            cf.model_path = os.path.join(exp_path, 'model.py')
            cf.backbone_path = os.path.join(exp_path, 'backbone.py')
        else:
            # this case overwrites settings files in the exp dir, i.e., default_configs, configs, backbone, model
            os.makedirs(exp_path, exist_ok=True)
            # run training with source code info and copy a snapshot of the model to the exp_dir for
            # later testing (overwrite scripts if the exp_dir already exists.)
            subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')),
                            shell=True)
            subprocess.call(
                'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')),
                shell=True)
            cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py'))
            cf = cf_file.configs(server_env)
            subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True)
            subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True)
            if os.path.isfile(os.path.join(exp_path, "folds_ids.pickle")):
                subprocess.call('rm {}'.format(os.path.join(exp_path, "folds_ids.pickle")), shell=True)

    else:
        # testing: use model and backbone stored in the exp dir.
        cf_file = import_module('cf_file', os.path.join(exp_path, 'configs.py'))
        cf = cf_file.configs(server_env)
        cf.model_path = os.path.join(exp_path, 'model.py')
        cf.backbone_path = os.path.join(exp_path, 'backbone.py')

    cf.exp_dir = exp_path
    cf.test_dir = os.path.join(cf.exp_dir, 'test')
    cf.plot_dir = os.path.join(cf.exp_dir, 'plots')
    if not os.path.exists(cf.test_dir):
        os.mkdir(cf.test_dir)
    if not os.path.exists(cf.plot_dir):
        os.mkdir(cf.plot_dir)
    cf.experiment_name = exp_path.split("/")[-1]
    cf.created_fold_id_pickle = False

    return cf

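# Resulting exp_dir layout after prep_exp (sketch; fold subdirectories appear once training runs):
#   exp_path/
#     default_configs.py  configs.py  model.py  backbone.py   # settings/model snapshot
#     logs/     # fold_<fold>/exec.log + tboard/ (created by get_logger / CombinedLogger)
#     test/     # test-set outputs
#     plots/    # monitoring/example plots
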
def import_module(name: str, path: str):
    """
    Import a module dynamically from a file path (python-3 importlib API).
    :param name: name given to the module instance.
    :param path: path to the module file.
    :return: imported module instance.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

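# Usage sketch (path illustrative): load a dataset's configs.py, as done in prep_exp:
#   cf_file = import_module('cf_file', 'lidc_exp/configs.py')
#   cf = cf_file.configs(server_env=False)
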
def set_params_flag(module: torch.nn.Module, flag: Tuple[str, Any], check_overwrite: bool = True) -> torch.nn.Module:
    """Set an attribute for all parameters of the passed module.

    :param flag: tuple (str attribute name, attr value).
    :param check_overwrite: if True, assert that the attribute does not already exist.
    """
    for param in module.parameters():
        if check_overwrite:
            assert not hasattr(param, flag[0]), \
                "param {} already has attr {} (w/ val {})".format(param, flag[0], getattr(param, flag[0]))
        setattr(param, flag[0], flag[1])
    return module

def parse_params_for_optim(net: torch.nn.Module, weight_decay: float = 0., exclude_from_wd: Iterable = ("norm",)) -> list:
    """Split network parameters into weight-decay dependent groups for the optimizer.
    :param net: network.
    :param weight_decay: weight decay value for the parameters that it is applied to. Excluded
        parameters get weight decay 0.
    :param exclude_from_wd: list of strings of parameter-group names to exclude from weight decay.
        Options: "norm", "bias".
    :return: list of parameter groups as expected by torch optimizers.
    """
    if weight_decay is None:
        weight_decay = 0.
    # pytorch implements parameter groups as dicts {'params': ...} and
    # weight decay as p.data.mul_(1 - group['lr'] * group['weight_decay'])
    norm_types = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
                  torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d,
                  torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.SyncBatchNorm, torch.nn.LocalResponseNorm]
    level_map = {"bias": "weight",
                 "norm": "module"}
    type_map = {"norm": norm_types}

    exclude_from_wd = [str(name).lower() for name in exclude_from_wd]
    exclude_weight_names = [k for k, v in level_map.items() if k in exclude_from_wd and v == "weight"]
    exclude_module_types = tuple([type_ for k, v in level_map.items() if (k in exclude_from_wd and v == "module")
                                  for type_ in type_map[k]])

    if exclude_from_wd:
        print("excluding {} from weight decay.".format(exclude_from_wd))

    # mark parameters to be excluded: whole modules (e.g. all norm layers) ...
    for module in net.modules():
        if isinstance(module, exclude_module_types):
            set_params_flag(module, ("no_wd", True))
    # ... and individually named parameters (e.g. all biases).
    for param_name, param in net.named_parameters():
        if any(ename in param_name for ename in exclude_weight_names):
            setattr(param, "no_wd", True)

    with_dec, no_dec = [], []
    for param in net.parameters():
        if getattr(param, "no_wd", False):
            no_dec.append(param)
        else:
            with_dec.append(param)
    orig_ps = sum(p.numel() for p in net.parameters())
    with_ps = sum(p.numel() for p in with_dec)
    wo_ps = sum(p.numel() for p in no_dec)
    assert orig_ps == with_ps + wo_ps, "orig n parameters {} does not equal sum of with wd {} and w/o wd {}."\
        .format(orig_ps, with_ps, wo_ps)

    groups = [{'params': gr, 'weight_decay': wd} for (gr, wd) in [(no_dec, 0.), (with_dec, weight_decay)] if len(gr) > 0]
    return groups

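# Usage sketch (hyperparameter values illustrative): per-group 'weight_decay' entries
# override the optimizer-wide default, so norm/bias params can be exempted:
#   groups = parse_params_for_optim(net, weight_decay=1e-5, exclude_from_wd=("norm", "bias"))
#   optimizer = torch.optim.AdamW(groups, lr=1e-4)
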
class ModelSelector:
    '''
    Saves a checkpoint after each epoch as 'last_checkpoint' (can be loaded to continue an
    interrupted training) and keeps the top-k (k=cf.save_n_models) ranked epochs. In inference,
    predictions of multiple epochs can be ensembled to improve performance.
    '''

    def __init__(self, cf, logger):

        self.cf = cf
        self.saved_epochs = [-1] * cf.save_n_models
        self.logger = logger

    def run_model_selection(self, net: torch.nn.Module, optimizer: torch.optim.Optimizer,
                            monitor_metrics: dict, epoch: int):

        # take the mean over all selection criteria in each epoch
        non_nan_scores = np.mean(np.array([[0 if (ii is None or np.isnan(ii)) else ii for ii in
                                            monitor_metrics['val'][sc]] for sc in self.cf.model_selection_criteria]), 0)
        epochs_scores = [ii for ii in non_nan_scores[1:]]
        # ranking of epochs according to model_selection_criteria
        epoch_ranking = np.argsort(epochs_scores, kind="stable")[::-1] + 1  # epochs start at 1
        # if set in configs, epochs < min_save_thresh are discarded from the saving process.
        epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh]

        # check if the current epoch is among the top-k epochs.
        if epoch in epoch_ranking[:self.cf.save_n_models]:

            save_dir = os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(epoch))
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)

            torch.save(net.state_dict(), os.path.join(save_dir, 'params.pth'))
            with open(os.path.join(save_dir, 'monitor_metrics.pickle'), 'wb') as handle:
                pickle.dump(monitor_metrics, handle)
            # save epoch_ranking to keep the info for inference.
            np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
            np.save(os.path.join(save_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])

            self.logger.info(
                "saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch)))
            # delete params of the epoch that just fell out of the top-k epochs.
            for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_checkpoint' in ii]:
                if se in epoch_ranking[self.cf.save_n_models:]:
                    subprocess.call('rm -rf {}'.format(os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(se))), shell=True)
                    self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se)))

        state = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        # save a checkpoint of the current epoch.
        save_dir = os.path.join(self.cf.fold_dir, 'last_checkpoint')
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save(state, os.path.join(save_dir, 'params.pth'))
        np.save(os.path.join(save_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
        with open(os.path.join(save_dir, 'monitor_metrics.pickle'), 'wb') as handle:
            pickle.dump(monitor_metrics, handle)

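# Usage sketch inside a training loop (cf fields as used above; loop bounds and the train
# step are illustrative placeholders):
#   selector = ModelSelector(cf, logger)
#   for epoch in range(1, cf.num_epochs + 1):
#       train_one_epoch(...)
#       selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
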
def load_checkpoint(checkpoint_path: str, net: torch.nn.Module, optimizer: torch.optim.Optimizer) -> Tuple:
    """Load a checkpoint as saved by ModelSelector (e.g. 'last_checkpoint') to resume training.
    :return: (starting_epoch, net, optimizer, monitor_metrics)
    """
    checkpoint = torch.load(os.path.join(checkpoint_path, 'params.pth'))
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    with open(os.path.join(checkpoint_path, 'monitor_metrics.pickle'), 'rb') as handle:
        monitor_metrics = pickle.load(handle)
    starting_epoch = checkpoint['epoch'] + 1
    return starting_epoch, net, optimizer, monitor_metrics

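# Resume sketch (path as produced by ModelSelector above):
#   ckpt_dir = os.path.join(cf.fold_dir, 'last_checkpoint')
#   starting_epoch, net, optimizer, monitor_metrics = load_checkpoint(ckpt_dir, net, optimizer)
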
def prepare_monitoring(cf):
    """
    Creates the dictionaries in which train/val metrics are stored over epochs.
    """
    metrics = {}
    # the first (nan) entry of each metric list accounts for epochs starting at 1.
    metrics['train'] = OrderedDict()
    metrics['val'] = OrderedDict()
    metric_classes = []
    if 'rois' in cf.report_score_level:
        metric_classes.extend([v for k, v in cf.class_dict.items()])
    if 'patient' in cf.report_score_level:
        metric_classes.extend(['patient'])
    for cl in metric_classes:
        metrics['train'][cl + '_ap'] = [np.nan]
        metrics['val'][cl + '_ap'] = [np.nan]
        if cl == 'patient':
            metrics['train'][cl + '_auc'] = [np.nan]
            metrics['val'][cl + '_auc'] = [np.nan]

    return metrics

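# Resulting structure (sketch, assuming cf.class_dict = {1: 'benign', 2: 'malignant'} and
# cf.report_score_level = ['rois', 'patient']):
#   metrics['val'] == OrderedDict([('benign_ap', [nan]), ('malignant_ap', [nan]),
#                                  ('patient_ap', [nan]), ('patient_auc', [nan])])
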
def create_csv_output(results_list, cf, logger):
    """
    Write out test-set predictions to a .csv file. Output format is one line per prediction:
    PatientID | PredictionID | [y1 x1 y2 x2 (z1) (z2)] | score | pred_classID
    Note that prediction coordinates correspond to images as loaded for training/testing and need
    to be adapted when plotted over the raw data (before preprocessing/resampling).
    :param results_list: [[patient_results, patient_id], [patient_results, patient_id], ...]
    """

    try:
        fold = cf.fold
    except AttributeError:
        fold = 'hold_out'
    out_file = os.path.join(cf.exp_dir, 'results_{}.csv'.format(fold))
    logger.info('creating csv output file at {}'.format(out_file))
    predictions_df = pd.DataFrame(columns=['patientID', 'predictionID', 'coords', 'score', 'pred_classID'])
    for r in results_list:

        pid = r[1]

        # optionally load resampling info from preprocessing to match output predictions with raw data.
        # with open(os.path.join(cf.exp_dir, 'test_resampling_info', pid), 'rb') as handle:
        #     resampling_info = pickle.load(handle)

        for bix, box in enumerate(r[0][0]):
            if box["box_type"] == "gt":
                continue
            assert box['box_type'] == 'det', box['box_type']
            coords = box['box_coords']
            score = box['box_score']
            pred_class_id = box['box_pred_class_id']
            out_coords = []
            if score >= cf.min_det_thresh:
                out_coords.append(coords[0])  # * resampling_info['scale'][0]
                out_coords.append(coords[1])  # * resampling_info['scale'][1]
                out_coords.append(coords[2])  # * resampling_info['scale'][0]
                out_coords.append(coords[3])  # * resampling_info['scale'][1]
                if len(coords) > 4:
                    out_coords.append(coords[4])  # * resampling_info['scale'][2] + resampling_info['z_crop']
                    out_coords.append(coords[5])  # * resampling_info['scale'][2] + resampling_info['z_crop']

                predictions_df.loc[len(predictions_df)] = [pid, bix, out_coords, score, pred_class_id]
    predictions_df.to_csv(out_file, index=False)

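# Example output row (values illustrative), written to exp_dir/results_<fold>.csv:
#   patientID,predictionID,coords,score,pred_classID
#   pat_0042,3,"[12, 30, 58, 77, 10, 24]",0.91,1
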
class _AnsiColorizer(object):
    """
    A colorizer is an object that loosely wraps around a stream, allowing
    callers to write text to the stream in a particular color.

    Colorizer classes must implement C{supported()} and C{write(text, color)}.
    """
    _colors = dict(black=30, red=31, green=32, yellow=33,
                   blue=34, magenta=35, cyan=36, white=37, default=39)

    def __init__(self, stream):
        self.stream = stream

    @classmethod
    def supported(cls, stream=sys.stdout):
        """
        A class method that returns True if the current platform supports
        coloring terminal output using this method. Returns False otherwise.
        """
        if not stream.isatty():
            return False  # auto color only on TTYs
        try:
            import curses
        except ImportError:
            return False
        else:
            try:
                try:
                    return curses.tigetnum("colors") > 2
                except curses.error:
                    curses.setupterm()
                    return curses.tigetnum("colors") > 2
            except:
                # guess false in case of error
                return False

    def write(self, text, color):
        """
        Write the given text to the stream in the given color.

        @param text: Text to be written to the stream.

        @param color: A string label for a color. e.g. 'red', 'white'.
        """
        color = self._colors[color]
        self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text))

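# The escape sequence written above follows the ANSI SGR convention, e.g. for 'red' (code 31):
#   '\x1b[31mhello\x1b[0m'   # '\x1b[31m' switches to red, '\x1b[0m' resets to the default color.
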
class ColorHandler(logging.StreamHandler):

    def __init__(self, stream=sys.stdout):
        super(ColorHandler, self).__init__(_AnsiColorizer(stream))

    def emit(self, record):
        msg_colors = {
            logging.DEBUG: "green",
            logging.INFO: "default",
            logging.WARNING: "red",
            logging.ERROR: "red"
        }
        color = msg_colors.get(record.levelno, "blue")
        self.stream.write(record.getMessage() + "\n", color)

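# Usage sketch: ColorHandler is attached by CombinedLogger when server_env=False, so that e.g.
#   logger.debug("shapes ok")      # prints green
#   logger.warning("nan in loss")  # prints red
# while in a server environment a plain logging.StreamHandler is used instead.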