|
a |
|
b/test_fpred_patch.py |
|
|
1 |
import string |
|
|
2 |
import sys |
|
|
3 |
import lasagne as nn |
|
|
4 |
import numpy as np |
|
|
5 |
import theano |
|
|
6 |
import buffering |
|
|
7 |
import pathfinder |
|
|
8 |
import utils |
|
|
9 |
from configuration import config, set_configuration |
|
|
10 |
from utils_plots import plot_slice_3d_3 |
|
|
11 |
import utils_lung |
|
|
12 |
import logger |
|
|
13 |
|
|
|
14 |
theano.config.warn_float64 = 'raise' |
|
|
15 |
|
|
|
16 |
if len(sys.argv) < 2: |
|
|
17 |
sys.exit("Usage: train.py <configuration_name>") |
|
|
18 |
|
|
|
19 |
config_name = sys.argv[1] |
|
|
20 |
set_configuration('configs_fpred_patch', config_name) |
|
|
21 |
|
|
|
22 |
# metadata |
|
|
23 |
metadata_dir = utils.get_dir_path('models', pathfinder.METADATA_PATH) |
|
|
24 |
metadata_path = utils.find_model_metadata(metadata_dir, config_name) |
|
|
25 |
|
|
|
26 |
metadata = utils.load_pkl(metadata_path) |
|
|
27 |
expid = metadata['experiment_id'] |
|
|
28 |
|
|
|
29 |
# logs |
|
|
30 |
logs_dir = utils.get_dir_path('logs', pathfinder.METADATA_PATH) |
|
|
31 |
sys.stdout = logger.Logger(logs_dir + '/%s-test.log' % expid) |
|
|
32 |
sys.stderr = sys.stdout |
|
|
33 |
|
|
|
34 |
# predictions path |
|
|
35 |
predictions_dir = utils.get_dir_path('model-predictions', pathfinder.METADATA_PATH) |
|
|
36 |
outputs_path = predictions_dir + '/' + expid |
|
|
37 |
utils.auto_make_dir(outputs_path) |
|
|
38 |
|
|
|
39 |
print 'Build model' |
|
|
40 |
model = config().build_model() |
|
|
41 |
all_layers = nn.layers.get_all_layers(model.l_out) |
|
|
42 |
all_params = nn.layers.get_all_params(model.l_out) |
|
|
43 |
num_params = nn.layers.count_params(model.l_out) |
|
|
44 |
print ' number of parameters: %d' % num_params |
|
|
45 |
print string.ljust(' layer output shapes:', 36), |
|
|
46 |
print string.ljust('#params:', 10), |
|
|
47 |
print 'output shape:' |
|
|
48 |
for layer in all_layers: |
|
|
49 |
name = string.ljust(layer.__class__.__name__, 32) |
|
|
50 |
num_param = sum([np.prod(p.get_value().shape) for p in layer.get_params()]) |
|
|
51 |
num_param = string.ljust(num_param.__str__(), 10) |
|
|
52 |
print ' %s %s %s' % (name, num_param, layer.output_shape) |
|
|
53 |
|
|
|
54 |
nn.layers.set_all_param_values(model.l_out, metadata['param_values']) |
|
|
55 |
|
|
|
56 |
valid_loss = config().build_objective(model, deterministic=True) |
|
|
57 |
|
|
|
58 |
x_shared = nn.utils.shared_empty(dim=len(model.l_in.shape)) |
|
|
59 |
y_shared = nn.utils.shared_empty(dim=len(model.l_target.shape)) |
|
|
60 |
|
|
|
61 |
givens_valid = {} |
|
|
62 |
givens_valid[model.l_in.input_var] = x_shared |
|
|
63 |
givens_valid[model.l_target.input_var] = y_shared |
|
|
64 |
|
|
|
65 |
# theano functions |
|
|
66 |
iter_get_predictions = theano.function([], [valid_loss, nn.layers.get_output(model.l_out, deterministic=True)], |
|
|
67 |
givens=givens_valid) |
|
|
68 |
valid_data_iterator = config().valid_data_iterator |
|
|
69 |
|
|
|
70 |
print |
|
|
71 |
print 'Data' |
|
|
72 |
print 'n validation: %d' % valid_data_iterator.nsamples |
|
|
73 |
threshold = 0.2 |
|
|
74 |
n_tp, n_tn, n_fp, n_fn = 0, 0, 0, 0 |
|
|
75 |
n_pos = 0 |
|
|
76 |
n_neg = 0 |
|
|
77 |
|
|
|
78 |
validation_losses = [] |
|
|
79 |
for n, (x_chunk, y_chunk, id_chunk) in enumerate(buffering.buffered_gen_threaded(valid_data_iterator.generate())): |
|
|
80 |
# load chunk to GPU |
|
|
81 |
x_shared.set_value(x_chunk) |
|
|
82 |
y_shared.set_value(y_chunk) |
|
|
83 |
loss, predictions = iter_get_predictions() |
|
|
84 |
validation_losses.append(loss) |
|
|
85 |
targets = y_chunk[0, 0] |
|
|
86 |
p1 = predictions[0][1] |
|
|
87 |
if targets == 1 and p1 >= threshold: |
|
|
88 |
n_tp += 1 |
|
|
89 |
if targets == 1 and p1 < threshold: |
|
|
90 |
n_fn += 1 |
|
|
91 |
if targets == 0 and p1 >= threshold: |
|
|
92 |
n_fp += 1 |
|
|
93 |
if targets == 0 and p1 < threshold: |
|
|
94 |
n_tn += 1 |
|
|
95 |
if targets == 1: |
|
|
96 |
n_pos += 1 |
|
|
97 |
else: |
|
|
98 |
n_neg += 1 |
|
|
99 |
|
|
|
100 |
print id_chunk, targets, p1, loss |
|
|
101 |
|
|
|
102 |
print 'Validation loss', np.mean(validation_losses) |
|
|
103 |
print 'TP', n_tp |
|
|
104 |
print 'TN', n_tn |
|
|
105 |
print 'FP', n_fp |
|
|
106 |
print 'FN', n_fn |
|
|
107 |
print 'n neg', n_neg |
|
|
108 |
print 'n pos', n_pos |