Diff of /test_seg_scan.py [000000] .. [70b6b3]

Switch to unified view

a b/test_seg_scan.py
1
import sys
2
import lasagne as nn
3
import numpy as np
4
import theano
5
import pathfinder
6
import utils
7
from configuration import config, set_configuration
8
from utils_plots import plot_slice_3d_4
9
import theano.tensor as T
10
import blobs_detection
11
import logger
12
import time
13
import multiprocessing as mp
14
import buffering
15
16
17
def extract_candidates(predictions_scan, annotations, tf_matrix, pid, outputs_path):
18
    print 'computing blobs'
19
    start_time = time.time()
20
    blobs = blobs_detection.blob_dog(predictions_scan[0, 0], min_sigma=1, max_sigma=15, threshold=0.1)
21
    print 'blobs computation time:', (time.time() - start_time) / 60.
22
23
    print 'n_blobs detected', len(blobs)
24
    correct_blobs_idxs = []
25
    for zyxd in annotations:
26
        r = zyxd[-1] / 2.
27
        distance2 = ((zyxd[0] - blobs[:, 0]) ** 2
28
                     + (zyxd[1] - blobs[:, 1]) ** 2
29
                     + (zyxd[2] - blobs[:, 2]) ** 2)
30
        blob_idx = np.argmin(distance2)
31
        print 'node', zyxd
32
        print 'closest blob', blobs[blob_idx]
33
        if distance2[blob_idx] <= r ** 2:
34
            correct_blobs_idxs.append(blob_idx)
35
        else:
36
            print 'not detected !!!'
37
38
    # we will save blobs the the voxel space of the original image
39
    # blobs that are true detections will have blobs[-1] = 1 else 0
40
    blobs_original_voxel_coords = []
41
    for j in xrange(blobs.shape[0]):
42
        blob_j = np.append(blobs[j, :3], [1])
43
        blob_j_original = tf_matrix.dot(blob_j)
44
        blob_j_original[-1] = 1 if j in correct_blobs_idxs else 0
45
        if j in correct_blobs_idxs:
46
            print 'blob in original', blob_j_original
47
        blobs_original_voxel_coords.append(blob_j_original)
48
49
    blobs = np.asarray(blobs_original_voxel_coords)
50
    utils.save_pkl(blobs, outputs_path + '/%s.pkl' % pid)
51
52
53
# Candidate-extraction worker processes spawned in the main loop below.
jobs = []
# Fail loudly if any float64 sneaks into the Theano graph (GPU runs want float32).
theano.config.warn_float64 = 'raise'

# The configuration name is a required command-line argument.
if len(sys.argv) < 2:
    sys.exit("Usage: test_luna_scan.py <configuration_name>")

config_name = sys.argv[1]
set_configuration('configs_seg_scan', config_name)

# predictions path: one output directory per configuration
predictions_dir = utils.get_dir_path('model-predictions', pathfinder.METADATA_PATH)
outputs_path = predictions_dir + '/%s' % config_name
utils.auto_make_dir(outputs_path)

# logs: tee stdout/stderr of this run into a per-configuration log file
logs_dir = utils.get_dir_path('logs', pathfinder.METADATA_PATH)
sys.stdout = logger.Logger(logs_dir + '/%s.log' % config_name)
sys.stderr = sys.stdout
71
72
# builds model and sets its parameters
model = config().build_model()

# Shared variable holding the current input patch; the compiled function below
# reads the network input from here instead of taking an explicit argument.
x_shared = nn.utils.shared_empty(dim=len(model.l_in.shape))
# NOTE(review): these three index scalars are never used below — presumably
# leftovers from an indexed-givens variant of this script.
idx_z = T.lscalar('idx_z')
idx_y = T.lscalar('idx_y')
idx_x = T.lscalar('idx_x')

# Sliding-window geometry: n_windows steps of `stride` voxels per axis,
# each window `window_size` voxels wide.
window_size = config().window_size
stride = config().stride
n_windows = config().n_windows

givens = {}
givens[model.l_in.input_var] = x_shared

# Compiled predictor: no explicit inputs, the patch comes via `givens`.
get_predictions_patch = theano.function([],
                                        nn.layers.get_output(model.l_out, deterministic=True),
                                        givens=givens,
                                        on_unused_input='ignore')

valid_data_iterator = config().valid_data_iterator

print
print 'Data'
print 'n samples: %d' % valid_data_iterator.nsamples
97
98
start_time = time.time()
99
for n, (x, y, lung_mask, annotations, tf_matrix, pid) in enumerate(
100
        buffering.buffered_gen_threaded(valid_data_iterator.generate(), buffer_size=2)):
101
    print '-------------------------------------'
102
    print n, pid
103
104
    predictions_scan = np.zeros((1, 1, n_windows * stride, n_windows * stride, n_windows * stride))
105
106
    for iz in xrange(n_windows):
107
        for iy in xrange(n_windows):
108
            for ix in xrange(n_windows):
109
                start_time_patch = time.time()
110
                x_shared.set_value(x[:, :, iz * stride:(iz * stride) + window_size,
111
                                   iy * stride:(iy * stride) + window_size,
112
                                   ix * stride:(ix * stride) + window_size])
113
                predictions_patch = get_predictions_patch()
114
115
                predictions_scan[0, 0,
116
                iz * stride:(iz + 1) * stride,
117
                iy * stride:(iy + 1) * stride,
118
                ix * stride:(ix + 1) * stride] = predictions_patch
119
120
    if predictions_scan.shape != x.shape:
121
        pad_width = (np.asarray(x.shape) - np.asarray(predictions_scan.shape)) / 2
122
        pad_width = [(p, p) for p in pad_width]
123
        predictions_scan = np.pad(predictions_scan, pad_width=pad_width, mode='constant')
124
125
    if lung_mask is not None:
126
        predictions_scan *= lung_mask
127
128
    for nodule_n, zyxd in enumerate(annotations):
129
        plot_slice_3d_4(input=x[0, 0], mask=y[0, 0], prediction=predictions_scan[0, 0],
130
                        lung_mask=lung_mask[0, 0] if lung_mask is not None else x[0, 0],
131
                        axis=0, pid='-'.join([str(n), str(nodule_n), str(pid)]),
132
                        img_dir=outputs_path, idx=zyxd)
133
    print 'saved plot'
134
    print 'time since start:', (time.time() - start_time) / 60.
135
136
    jobs = [job for job in jobs if job.is_alive]
137
    if len(jobs) >= 3:
138
        jobs[0].join()
139
        del jobs[0]
140
    jobs.append(
141
        mp.Process(target=extract_candidates, args=(predictions_scan, annotations, tf_matrix, pid, outputs_path)))
142
    jobs[-1].daemon = True
143
    jobs[-1].start()
144
145
for job in jobs: job.join()