a b/data_iterators.py
1
import numpy as np
2
import utils_lung
3
import pathfinder
4
import utils
5
6
7
# Anchor points of a piecewise-linear map from nodule diameter (mm) to an
# approximate probability of malignancy. Each segment is stored as a
# (slope, offset) pair so that p = slope * diameter + offset on that segment.


def _linear_segment(x_start, p_start, x_end, p_end):
    """Return (slope, offset) of the line through (x_start, p_start) and (x_end, p_end)."""
    slope = (p_end - p_start) / (x_end - x_start)
    return slope, p_start - slope * x_start


# 6% to 28% for nodules 5 to 10 mm; the 5 mm anchor is the midpoint of 1% and 6%.
prob5 = (0.01+0.06)/2.
slope10, offset10 = _linear_segment(5., prob5, 10., 0.28)

# 10 to 20 mm: interpolate between the 28% and 64% anchors.
slope20, offset20 = _linear_segment(10., 0.28, 20., 0.64)

# 64% to 82% for nodules >20 mm in diameter (20 to 25 mm segment) ...
slope25, offset25 = _linear_segment(20., 0.64, 25., 0.82)

# ... rising to 93% at 30 mm.
slope30, offset30 = _linear_segment(25., 0.82, 30., 0.93)

# For nodules more than 3 cm in diameter, 93% to 97% are malignant
# (the 97% anchor is placed at 40 mm).
slope40, offset40 = _linear_segment(30., 0.93, 40., 0.97)
25
26
def diameter_to_prob(diam):
    """Map a nodule diameter (mm) to an approximate probability of malignancy.

    The curve is piecewise linear, built from the module-level slope/offset
    constants, and the result is clipped to [0, 1]. Diameters of 30 mm and
    above all use the 30-40 mm segment, extrapolated beyond 40 mm.
    """
    # The prevalence of malignancy is 0% to 1% for nodules <5 mm:
    # ramp linearly from 0 at 0 mm up to prob5 at 5 mm.
    if diam < 5:
        return np.clip(prob5*diam/5., 0., 1.)

    # (exclusive upper bound in mm, slope, offset) for the remaining segments.
    segments = (
        (10, slope10, offset10),
        (20, slope20, offset20),
        (25, slope25, offset25),
        (30, slope30, offset30),
    )
    slope, offset = slope40, offset40
    for upper_bound, seg_slope, seg_offset in segments:
        if diam < upper_bound:
            slope, offset = seg_slope, seg_offset
            break
    return np.clip(slope * diam + offset, 0., 1.)
41
42
43
class LunaDataGenerator(object):
    """Yield one preprocessed LUNA scan per patient, optionally forever.

    Scans are ``.mhd`` files under ``data_path``. Each one is passed through
    ``data_prep_fun`` together with the patient's entry from the LUNA
    annotations file.
    """

    def __init__(self, data_path, transform_params, data_prep_fun, rng,
                 random, infinite, patient_ids=None, **kwargs):
        self.patient_ids = patient_ids
        if not patient_ids:
            # No explicit ids (None or empty): use every .mhd path found on disk.
            discovered = utils_lung.get_patient_data_paths(data_path)
            self.patient_paths = [path for path in discovered if '.mhd' in path]
        else:
            self.patient_paths = [data_path + '/' + pid + '.mhd' for pid in patient_ids]

        self.id2annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.rng = rng
        self.random = random
        self.infinite = infinite
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params

    def generate(self):
        """Yield ``(x, y, None, annotations, tf_matrix, pid)`` for each scan.

        ``x`` and ``y`` are float32 with two leading singleton axes added, so
        ``data_prep_fun`` must return 3-D arrays for them. The visiting order is
        shuffled with ``self.rng`` when ``self.random`` is set; with
        ``self.infinite`` the pass repeats forever.
        NOTE(review): ``self.id2annotations[pid]`` will raise for a scan without
        an annotations entry if the mapping is a plain dict -- confirm.
        """
        while True:
            order = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(order)

            for idx in order:
                patient_path = self.patient_paths[idx]
                pid = utils_lung.extract_pid_filename(patient_path)
                img, origin, pixel_spacing = utils_lung.read_mhd(patient_path)
                x, y, annotations, tf_matrix = self.data_prep_fun(
                    data=img,
                    pixel_spacing=pixel_spacing,
                    luna_annotations=self.id2annotations[pid],
                    luna_origin=origin)
                yield (np.float32(x)[None, None, :, :, :],
                       np.float32(y)[None, None, :, :, :],
                       None, annotations, tf_matrix, pid)

            if not self.infinite:
                break
88
89
90
91
class LunaSimpleDataGenerator(object):
92
    def __init__(self, data_path, patient_ids=None, **kwargs):
93
94
        self.patient_ids = patient_ids
95
96
        self.data_path = data_path
97
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
98
99
        if patient_ids:
100
            self.patient_paths = [data_path + '/' + p + self.file_extension for p in patient_ids]
101
        else:
102
            patient_paths = utils_lung.get_patient_data_paths(data_path)
103
            self.patient_paths = [p for p in patient_paths if self.file_extension in p]
104
        
105
        self.nsamples = len(self.patient_paths)
106
107
        print self.data_path
108
109
    def generate(self):
110
        for patient_path in self.patient_paths:
111
            pid = utils_lung.extract_pid_filename(patient_path)
112
113
            img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
114
                if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
115
116
            x = np.float32(img)
117
118
            yield x, pid
119
120
121
122
123
class LunaScanPositiveDataGenerator(LunaDataGenerator):
    """LunaDataGenerator restricted to patients present in the LUNA annotations."""

    def __init__(self, data_path, transform_params, data_prep_fun, rng,
                 random, infinite, patient_ids=None, **kwargs):
        super(LunaScanPositiveDataGenerator, self).__init__(data_path, transform_params, data_prep_fun, rng,
                                                            random, infinite, patient_ids, **kwargs)
        patient_ids_all = [utils_lung.extract_pid_filename(p) for p in self.patient_paths]
        # Test membership on the mapping itself. Under Python 2, ``.keys()``
        # builds a new list for every check, which turns this into a linear scan.
        patient_ids_pos = [pid for pid in patient_ids_all if pid in self.id2annotations]
        # Paths are rebuilt as data_path/<pid>.mhd for the kept patients.
        self.patient_paths = [data_path + '/' + p + '.mhd' for p in patient_ids_pos]
        self.nsamples = len(self.patient_paths)
132
133
134
class LunaScanPositiveLungMaskDataGenerator(LunaDataGenerator):
    """Like LunaDataGenerator, but ``data_prep_fun`` also returns a lung mask.

    ``batch_size`` and ``full_batch`` are accepted for interface compatibility
    and are not used.
    NOTE(review): despite the name, no positive-only filtering is done here;
    this class subclasses LunaDataGenerator, not LunaScanPositiveDataGenerator.
    Confirm that every scan has an entry in ``id2annotations``.
    """

    def __init__(self, data_path, batch_size, transform_params, data_prep_fun, rng,
                 full_batch, random, infinite, patient_ids=None, **kwargs):
        super(LunaScanPositiveLungMaskDataGenerator, self).__init__(
            data_path, transform_params, data_prep_fun, rng,
            random, infinite, patient_ids, **kwargs)

    def generate(self):
        """Yield ``(x, y, lung_mask, annotations, tf_matrix, pid)`` per scan.

        ``x``, ``y`` and ``lung_mask`` are float32 with two leading singleton
        axes added, so ``data_prep_fun`` must return 3-D arrays for them.
        """
        while True:
            order = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(order)

            for idx in order:
                patient_path = self.patient_paths[idx]
                pid = utils_lung.extract_pid_filename(patient_path)
                img, origin, pixel_spacing = utils_lung.read_mhd(patient_path)
                x, y, lung_mask, annotations, tf_matrix = self.data_prep_fun(
                    data=img,
                    pixel_spacing=pixel_spacing,
                    luna_annotations=self.id2annotations[pid],
                    luna_origin=origin)

                x, y, lung_mask = [np.float32(volume)[None, None, :, :, :]
                                   for volume in (x, y, lung_mask)]
                yield x, y, lung_mask, annotations, tf_matrix, pid

            if not self.infinite:
                break
167
168
169
170
class LunaScanMaskPositiveDataGenerator(LunaDataGenerator):
    """Yield CT scans paired with a precomputed mask read from ``seg_data_path``.

    The mask for a patient is ``seg_data_path/<pid>.mhd``. ``batch_size`` and
    ``full_batch`` are accepted for interface compatibility and are not used.
    """

    def __init__(self, data_path, seg_data_path, batch_size, transform_params, data_prep_fun, rng,
                 full_batch, random, infinite, patient_ids=None, **kwargs):
        super(LunaScanMaskPositiveDataGenerator, self).__init__(data_path, transform_params,
                                                                data_prep_fun, rng,
                                                                random, infinite, patient_ids, **kwargs)
        self.seg_data_path = seg_data_path
        # Without explicit ids, the parent discovers the scans on disk, so take
        # the ids from those paths. Iterating patient_ids=None would raise, and
        # an empty list would leave mask_paths out of step with patient_paths.
        if self.patient_ids:
            mask_ids = self.patient_ids
        else:
            mask_ids = [utils_lung.extract_pid_filename(p) for p in self.patient_paths]
        self.mask_paths = [seg_data_path + '/' + p + '.mhd' for p in mask_ids]

    def generate(self):
        """Yield ``(ct, lung_mask, annotations, tf_matrix, pid)`` per patient.

        ``ct`` and ``lung_mask`` get two leading singleton axes, so
        ``data_prep_fun`` must return 3-D arrays for them.
        """
        while True:
            rand_idxs = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(rand_idxs)
            for idx in rand_idxs:
                ct_scan_path = self.patient_paths[idx]
                mask_path = self.mask_paths[idx]

                pid = utils_lung.extract_pid_filename(ct_scan_path)

                ct_scan, ct_origin, ct_pixel_spacing = utils_lung.read_mhd(ct_scan_path)
                mask, mask_origin, mask_pixel_spacing = utils_lung.read_mhd(mask_path)

                # The scan and its mask must have matching origin and pixel spacing.
                assert(sum(abs(ct_origin-mask_origin)) < 1e-9)
                assert(sum(abs(ct_pixel_spacing-mask_pixel_spacing)) < 1e-9)

                ct, lung_mask, annotations, tf_matrix = self.data_prep_fun(ct_scan=ct_scan, mask=mask,
                                                                           pixel_spacing=ct_pixel_spacing,
                                                                           luna_annotations=
                                                                           self.id2annotations[pid],
                                                                           luna_origin=ct_origin)

                ct = np.float32(ct)[None, None, :, :, :]
                lung_mask = np.float32(lung_mask)[None, None, :, :, :]

                yield ct, lung_mask, annotations, tf_matrix, pid

            if not self.infinite:
                break
211
212
213
#for lung segmentation, does not work yet
214
class PatchLunaDataGenerator(object):
215
    def __init__(self, ct_data_path, seg_data_path, batch_size, transform_params, data_prep_fun, rng,
216
                 full_batch, random, infinite, patient_ids=None, **kwargs):
217
218
        if patient_ids:
219
            self.patient_ids = patient_ids
220
            #self.patient_paths = [data_path + '/' + p + '.mhd' for p in patient_ids]
221
        else:
222
            patient_paths = utils_lung.get_patient_data_paths(data_path)
223
            #self.patient_paths = [p for p in patient_paths if '.mhd' in p]
224
            self.patient_ids = [utils_lung.extract_pid_filename(p) for p in self.patient_paths]\
225
226
        self.nsamples = len(self.patient_ids)
227
        self.ct_data_path = ct_data_path
228
        self.seg_data_path = seg_data_path
229
        self.rng = rng
230
        self.random = random
231
        self.infinite = infinite
232
        self.data_prep_fun = data_prep_fun
233
        self.transform_params = transform_params
234
        self.batch_size = batch_size
235
        self.full_batch = full_batch
236
237
    def generate(self):
238
        while True:
239
            rand_idxs = np.arange(self.nsamples)
240
            if self.random:
241
                self.rng.shuffle(rand_idxs)
242
            for pos in xrange(0, len(rand_idxs), self.batch_size):
243
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
244
                nb = len(idxs_batch)
245
                # allocate batches
246
                x_batch = np.zeros((nb, 1) + self.transform_params['patch_size'], dtype='float32')
247
                y_batch = np.zeros((nb, 1) + self.transform_params['patch_size'], dtype='float32')
248
                patients_ids = []
249
250
                for i, idx in enumerate(idxs_batch):
251
                    pid = self.patient_ids[idx]
252
                    ct_path = self.ct_data_path + pid + '.mhd'
253
                    seg_path = self.seg_data_path + pid + '.mhd'
254
                    patients_ids.append(pid)
255
256
                    ct_img, ct_origin, ct_pixel_spacing = utils_lung.read_mhd(ct_path)
257
                    seg_img, seg_origin, seg_pixel_spacing = utils_lung.read_mhd(seg_path)
258
259
                    assert(np.sum(ct_origin-seg_origin) <  1e-9)
260
                    assert(np.sum(ct_pixel_spacing-seg_pixel_spacing) <  1e-9)
261
262
                    print 'ct_img.shape', ct_img.shape
263
                    print 'seg_img.shape', seg_img.shape
264
                    w,h,d = self.transform_params['patch_size']
265
                    patch_center = [self.rng.randint(w/2, ct_img.shape[0]-w/2),
266
                                    self.rng.randint(h/2, ct_img.shape[1]-h/2),
267
                                    self.rng.randint(d/2, ct_img.shape[1]-d/2)]
268
                    print patch_center
269
270
271
                    x_batch[i, 0, :, :, :], y_batch[i, 0, :, :, :]  = self.data_prep_fun(ct_img=ct_img, seg_img=seg_img,
272
                                                                    patch_center=patch_center,
273
                                                                    pixel_spacing=ct_pixel_spacing,
274
                                                                    luna_origin=ct_origin)
275
276
                    # y_batch[i, 0, :, :, :],  = self.data_prep_fun(data=seg_img,
277
                    #                                                 patch_center=patch_center,
278
                    #                                                 pixel_spacing=seg_pixel_spacing,
279
                    #                                                 luna_origin=seg_origin)
280
                if self.full_batch:
281
                    if nb == self.batch_size:
282
                        yield x_batch, y_batch, patients_ids
283
                else:
284
                    yield x_batch, y_batch, patients_ids
285
286
            if not self.infinite:
287
                break
288
289
#works, tested
290
class LunaScanDataGenerator(object):
291
    def __init__(self, ct_data_path, seg_data_path, patient_ids=None, **kwargs):
292
293
        if patient_ids:
294
            self.patient_ids = patient_ids
295
            #self.patient_paths = [data_path + '/' + p + '.mhd' for p in patient_ids]
296
        else:
297
            patient_paths = utils_lung.get_patient_data_paths(ct_data_path)
298
            #self.patient_paths = [p for p in patient_paths if '.mhd' in p]
299
            self.patient_ids = [utils_lung.extract_pid_filename(p) for p in self.patient_paths]\
300
301
        self.nsamples = len(self.patient_ids)
302
        self.ct_data_path = ct_data_path
303
        self.seg_data_path = seg_data_path
304
        
305
306
    def generate(self):
307
        for pid in self.patient_ids:
308
            ct_path = self.ct_data_path + pid + '.mhd'
309
            seg_path = self.seg_data_path + pid + '.mhd'
310
311
            ct_img, ct_origin, ct_pixel_spacing = utils_lung.read_mhd(ct_path)
312
            seg_img, seg_origin, seg_pixel_spacing = utils_lung.read_mhd(seg_path)
313
314
            assert(np.sum(ct_origin-seg_origin) <  1e-9)
315
            assert(np.sum(ct_pixel_spacing-seg_pixel_spacing) <  1e-9)
316
317
            print 'ct_img.shape', ct_img.shape
318
            print 'seg_img.shape', seg_img.shape
319
320
            yield ct_img, seg_img, pid
321
322
323
class PatchPositiveLunaDataGenerator(object):
    """Yield batches of patches centred on annotated nodules.

    Only scans whose patient id appears in the LUNA annotations are used. For
    each patient in a batch, one of its annotations is chosen at random with
    ``self.rng`` as the patch centre.
    """

    def __init__(self, data_path, batch_size, transform_params, data_prep_fun, rng,
                 full_batch, random, infinite, patient_ids=None, **kwargs):

        self.id2annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)

        if patient_ids:
            self.patient_paths = [data_path + '/' + p + '.mhd' for p in patient_ids]
        else:
            patient_paths = utils_lung.get_patient_data_paths(data_path)
            self.patient_paths = [p for p in patient_paths if '.mhd' in p]

        patient_ids_all = [utils_lung.extract_pid_filename(p) for p in self.patient_paths]
        # Test membership on the mapping itself. Under Python 2, ``.keys()``
        # builds a new list for every check.
        patient_ids_pos = [pid for pid in patient_ids_all if pid in self.id2annotations]
        self.patient_paths = [data_path + '/' + p + '.mhd' for p in patient_ids_pos]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.rng = rng
        self.random = random
        self.infinite = infinite
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.batch_size = batch_size
        self.full_batch = full_batch

    def generate(self):
        """Yield ``(x_batch, y_batch, patient_ids)``, both arrays ``(nb, 1) + patch_size``.

        With ``full_batch``, a trailing batch shorter than ``batch_size`` is not yielded.
        """
        while True:
            rand_idxs = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(rand_idxs)
            for pos in xrange(0, len(rand_idxs), self.batch_size):
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
                nb = len(idxs_batch)
                # allocate batches
                x_batch = np.zeros((nb, 1) + self.transform_params['patch_size'], dtype='float32')
                y_batch = np.zeros((nb, 1) + self.transform_params['patch_size'], dtype='float32')
                patients_ids = []

                for i, idx in enumerate(idxs_batch):
                    patient_path = self.patient_paths[idx]
                    pid = utils_lung.extract_pid_filename(patient_path)
                    patients_ids.append(pid)
                    img, origin, pixel_spacing = utils_lung.read_mhd(patient_path)

                    patient_annotations = self.id2annotations[pid]
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
                    x_batch[i, 0, :, :, :], y_batch[i, 0, :, :, :] = self.data_prep_fun(
                        data=img,
                        patch_center=patch_center,
                        pixel_spacing=pixel_spacing,
                        luna_annotations=patient_annotations,
                        luna_origin=origin)
                if self.full_batch:
                    if nb == self.batch_size:
                        yield x_batch, y_batch, patients_ids
                else:
                    yield x_batch, y_batch, patients_ids

            if not self.infinite:
                break
383
384
385
386
class ValidPatchPositiveLunaDataGenerator(object):
    """Yield one single-sample batch per annotated nodule, for validation.

    Patients in ``patient_ids`` without LUNA annotations are skipped.
    ``nsamples`` is the total number of annotated nodules counted.
    """

    def __init__(self, data_path, transform_params, patient_ids, data_prep_fun, **kwargs):

        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)

        self.id2positive_annotations = {}
        self.id2patient_path = {}
        n_positive = 0
        for pid in patient_ids:
            if pid in id2positive_annotations:
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
                self.id2patient_path[pid] = data_path + '/' + pid + '.mhd'
                n_positive += len(id2positive_annotations[pid])

        self.nsamples = n_positive
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params

    def generate(self):
        """Yield ``(x_batch, y_batch, [pid])`` per nodule with two leading singleton axes."""
        for pid in self.id2positive_annotations:
            patient_annotations = self.id2positive_annotations[pid]
            if not patient_annotations:
                continue
            # Read the scan once per patient, not once per nodule
            # (FixedCandidatesLunaDataGenerator does the same). This assumes
            # data_prep_fun does not modify ``data`` in place -- confirm.
            img, origin, pixel_spacing = utils_lung.read_mhd(self.id2patient_path[pid])

            for patch_center in patient_annotations:
                x_batch, y_batch = self.data_prep_fun(data=img,
                                                      patch_center=patch_center,
                                                      pixel_spacing=pixel_spacing,
                                                      luna_annotations=patient_annotations,
                                                      luna_origin=origin)
                x_batch = np.float32(x_batch)[None, None, :, :, :]
                y_batch = np.float32(y_batch)[None, None, :, :, :]
                yield x_batch, y_batch, [pid]
422
423
424
class CandidatesLunaDataGenerator(object):
425
    def __init__(self, data_path, batch_size, transform_params, patient_ids, data_prep_fun, rng,
426
                 full_batch, random, infinite, positive_proportion, **kwargs):
427
428
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
429
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
430
431
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
432
        self.id2positive_annotations = {}
433
        self.id2negative_annotations = {}
434
        self.patient_paths = []
435
        n_positive, n_negative = 0, 0
436
        for pid in patient_ids:
437
            if pid in id2positive_annotations:
438
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
439
                self.id2negative_annotations[pid] = id2negative_annotations[pid]
440
                self.patient_paths.append(data_path + '/' + pid + self.file_extension)
441
                n_positive += len(id2positive_annotations[pid])
442
                n_negative += len(id2negative_annotations[pid])
443
444
        print 'n positive', n_positive
445
        print 'n negative', n_negative
446
447
        self.nsamples = len(self.patient_paths)
448
449
        print 'n patients', self.nsamples
450
        self.data_path = data_path
451
        self.batch_size = batch_size
452
        self.rng = rng
453
        self.full_batch = full_batch
454
        self.random = random
455
        self.infinite = infinite
456
        self.data_prep_fun = data_prep_fun
457
        self.transform_params = transform_params
458
        self.positive_proportion = positive_proportion
459
460
    def generate(self):
461
        while True:
462
            rand_idxs = np.arange(self.nsamples)
463
            if self.random:
464
                self.rng.shuffle(rand_idxs)
465
            for pos in xrange(0, len(rand_idxs), self.batch_size):
466
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
467
                nb = len(idxs_batch)
468
                # allocate batches
469
                x_batch = np.zeros((nb, 1) + self.transform_params['patch_size'], dtype='float32')
470
                y_batch = np.zeros((nb, 1), dtype='float32')
471
                patients_ids = []
472
473
                for i, idx in enumerate(idxs_batch):
474
                    patient_path = self.patient_paths[idx]
475
476
                    id = utils_lung.extract_pid_filename(patient_path, self.file_extension)
477
                    patients_ids.append(id)
478
479
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
480
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
481
                    if i < np.rint(self.batch_size * self.positive_proportion):
482
                        patient_annotations = self.id2positive_annotations[id]
483
                    else:
484
                        patient_annotations = self.id2negative_annotations[id]
485
486
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
487
488
                    y_batch[i] = float(patch_center[-1] > 0)
489
                    x_batch[i, 0, :, :, :] = self.data_prep_fun(data=img,
490
                                                                patch_center=patch_center,
491
                                                                pixel_spacing=pixel_spacing,
492
                                                                luna_origin=origin)
493
494
                if self.full_batch:
495
                    if nb == self.batch_size:
496
                        yield x_batch, y_batch, patients_ids
497
                else:
498
                    yield x_batch, y_batch, patients_ids
499
500
            if not self.infinite:
501
                break
502
503
504
505
class CandidatesLunaDataGenerator(object):
506
    def __init__(self, data_path, batch_size, transform_params, patient_ids, data_prep_fun, rng,
507
                 full_batch, random, infinite, positive_proportion, return_malignancy=False, **kwargs):
508
509
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
510
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
511
512
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
513
        self.id2positive_annotations = {}
514
        self.id2negative_annotations = {}
515
        self.patient_paths = []
516
        n_positive, n_negative = 0, 0
517
        for pid in patient_ids:
518
            if pid in id2positive_annotations:
519
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
520
                self.id2negative_annotations[pid] = id2negative_annotations[pid]
521
                self.patient_paths.append(data_path + '/' + pid + self.file_extension)
522
                n_positive += len(id2positive_annotations[pid])
523
                n_negative += len(id2negative_annotations[pid])
524
525
        print 'n positive', n_positive
526
        print 'n negative', n_negative
527
528
        self.nsamples = len(self.patient_paths)
529
530
        print 'n patients', self.nsamples
531
        self.data_path = data_path
532
        self.batch_size = batch_size
533
        self.rng = rng
534
        self.full_batch = full_batch
535
        self.random = random
536
        self.infinite = infinite
537
        self.data_prep_fun = data_prep_fun
538
        self.transform_params = transform_params
539
        self.positive_proportion = positive_proportion
540
        self.return_malignancy = return_malignancy
541
542
    def generate(self):
543
        while True:
544
            rand_idxs = np.arange(self.nsamples)
545
            if self.random:
546
                self.rng.shuffle(rand_idxs)
547
            for pos in xrange(0, len(rand_idxs), self.batch_size):
548
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
549
                nb = len(idxs_batch)
550
                # allocate batches
551
                x_batch = np.zeros((nb,) + self.transform_params['patch_size'], dtype='float32')
552
                y_batch = np.zeros((nb,), dtype='float32')
553
                patients_ids = []
554
555
                for i, idx in enumerate(idxs_batch):
556
                    patient_path = self.patient_paths[idx]
557
558
                    id = utils_lung.extract_pid_filename(patient_path, self.file_extension)
559
                    patients_ids.append(id)
560
561
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
562
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
563
                    if i < np.rint(self.batch_size * self.positive_proportion):
564
                        patient_annotations = self.id2positive_annotations[id]
565
                    else:
566
                        patient_annotations = self.id2negative_annotations[id]
567
568
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
569
570
                    if self.return_malignancy:
571
                        y_batch[i] = np.float32(diameter_to_prob(patch_center[-1]))
572
                    else:
573
                        y_batch[i] = float(patch_center[-1] > 0) 
574
                    x_batch[i, :, :, :] = self.data_prep_fun(data=img,
575
                                                                patch_center=patch_center,
576
                                                                pixel_spacing=pixel_spacing,
577
                                                                luna_origin=origin)
578
579
                if self.full_batch:
580
                    if nb == self.batch_size:
581
                        yield x_batch, y_batch, patients_ids
582
                else:
583
                    yield x_batch, y_batch, patients_ids
584
585
            if not self.infinite:
586
                break
587
588
589
class CandidatesLunaValidDataGenerator(object):
590
    def __init__(self, data_path, transform_params, patient_ids, data_prep_fun, return_malignancy=False, **kwargs):
591
        rng = np.random.RandomState(42)  # do not change this!!!
592
593
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
594
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
595
596
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
597
        self.id2positive_annotations = {}
598
        self.id2negative_annotations = {}
599
        self.id2patient_path = {}
600
        n_positive, n_negative = 0, 0
601
        for pid in patient_ids:
602
            if pid in id2positive_annotations:
603
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
604
                negative_annotations = id2negative_annotations[pid]
605
                n_pos = len(id2positive_annotations[pid])
606
                n_neg = len(id2negative_annotations[pid])
607
                neg_idxs = rng.choice(n_neg, size=n_pos, replace=False)
608
                negative_annotations_selected = []
609
                for i in neg_idxs:
610
                    negative_annotations_selected.append(negative_annotations[i])
611
                self.id2negative_annotations[pid] = negative_annotations_selected
612
613
                self.id2patient_path[pid] = data_path + '/' + pid + self.file_extension
614
                n_positive += n_pos
615
                n_negative += n_pos
616
617
        print 'n positive', n_positive
618
        print 'n negative', n_negative
619
620
        self.nsamples = len(self.id2patient_path)
621
        self.data_path = data_path
622
        self.rng = rng
623
        self.data_prep_fun = data_prep_fun
624
        self.transform_params = transform_params
625
        self.return_malignancy = return_malignancy
626
627
    def generate(self):
628
629
        for pid in self.id2positive_annotations.iterkeys():
630
            for patch_center in self.id2positive_annotations[pid]:
631
                patient_path = self.id2patient_path[pid]
632
633
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
634
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
635
                
636
                if self.return_malignancy:
637
                    y_batch = np.array([diameter_to_prob(patch_center[-1])], dtype='float32')
638
                else:
639
                    y_batch = np.array([1.], dtype='float32')
640
641
                x_batch = np.float32(self.data_prep_fun(data=img,
642
                                                        patch_center=patch_center,
643
                                                        pixel_spacing=pixel_spacing,
644
                                                        luna_origin=origin))[None, :, :, :]
645
646
                yield x_batch, y_batch, [pid]
647
648
            for patch_center in self.id2negative_annotations[pid]:
649
                patient_path = self.id2patient_path[pid]
650
651
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
652
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
653
                y_batch = np.array([0.], dtype='float32')
654
                x_batch = np.float32(self.data_prep_fun(data=img,
655
                                                        patch_center=patch_center,
656
                                                        pixel_spacing=pixel_spacing,
657
                                                        luna_origin=origin))[None, :, :, :]
658
659
                yield x_batch, y_batch, [pid]
660
661
662
class FixedCandidatesLunaDataGenerator(object):
663
    def __init__(self, data_path, transform_params, id2candidates_path, data_prep_fun, top_n=None):
664
665
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
666
        self.id2candidates_path = id2candidates_path
667
        self.id2patient_path = {}
668
        for pid in id2candidates_path.keys():
669
            self.id2patient_path[pid] = data_path + '/' + pid + self.file_extension
670
671
        self.nsamples = len(self.id2patient_path)
672
        self.data_path = data_path
673
        self.data_prep_fun = data_prep_fun
674
        self.transform_params = transform_params
675
        self.top_n = top_n
676
677
    def generate(self):
678
679
        for pid in self.id2candidates_path.iterkeys():
680
            patient_path = self.id2patient_path[pid]
681
            print 'PATIENT', pid
682
            candidates = utils.load_pkl(self.id2candidates_path[pid])
683
            if self.top_n is not None:
684
                candidates = candidates[:self.top_n]
685
                print candidates
686
            print 'n blobs', len(candidates)
687
688
            img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
689
                if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
690
691
            for candidate in candidates:
692
                y_batch = np.array(candidate, dtype='float32')
693
                patch_center = candidate[:3]
694
                x_batch = np.float32(self.data_prep_fun(data=img,
695
                                                        patch_center=patch_center,
696
                                                        pixel_spacing=pixel_spacing,
697
                                                        luna_origin=origin))[None, None, :, :, :]
698
699
                yield x_batch, y_batch, [pid]
700
701
702
703
class CandidatesLunaSizeDataGenerator(object):
704
    def __init__(self, data_path, batch_size, transform_params, patient_ids, data_prep_fun, rng,
705
                 full_batch, random, infinite, positive_proportion, **kwargs):
706
707
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
708
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
709
710
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
711
        self.id2positive_annotations = {}
712
        self.id2negative_annotations = {}
713
        self.patient_paths = []
714
        n_positive, n_negative = 0, 0
715
        for pid in patient_ids:
716
            if pid in id2positive_annotations:
717
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
718
                self.id2negative_annotations[pid] = id2negative_annotations[pid]
719
                self.patient_paths.append(data_path + '/' + pid + self.file_extension)
720
                n_positive += len(id2positive_annotations[pid])
721
                n_negative += len(id2negative_annotations[pid])
722
723
        print 'n positive', n_positive
724
        print 'n negative', n_negative
725
726
        self.nsamples = len(self.patient_paths)
727
728
        print 'n patients', self.nsamples
729
        self.data_path = data_path
730
        self.batch_size = batch_size
731
        self.rng = rng
732
        self.full_batch = full_batch
733
        self.random = random
734
        self.infinite = infinite
735
        self.data_prep_fun = data_prep_fun
736
        self.transform_params = transform_params
737
        self.positive_proportion = positive_proportion
738
739
    def generate(self):
740
        while True:
741
            rand_idxs = np.arange(self.nsamples)
742
            if self.random:
743
                self.rng.shuffle(rand_idxs)
744
            for pos in xrange(0, len(rand_idxs), self.batch_size):
745
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
746
                nb = len(idxs_batch)
747
                # allocate batches
748
                x_batch = np.zeros((nb, 1) + self.transform_params['patch_size'], dtype='float32')
749
                y_batch = np.zeros((nb, 1), dtype='float32')
750
                patients_ids = []
751
752
                for i, idx in enumerate(idxs_batch):
753
                    patient_path = self.patient_paths[idx]
754
755
                    id = utils_lung.extract_pid_filename(patient_path, self.file_extension)
756
                    patients_ids.append(id)
757
758
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
759
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
760
                    if i < np.rint(self.batch_size * self.positive_proportion):
761
                        patient_annotations = self.id2positive_annotations[id]
762
                    else:
763
                        patient_annotations = self.id2negative_annotations[id]
764
765
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
766
767
                    y_batch[i] = float(patch_center[-1])
768
                    x_batch[i, 0, :, :, :] = self.data_prep_fun(data=img,
769
                                                                patch_center=patch_center,
770
                                                                pixel_spacing=pixel_spacing,
771
                                                                luna_origin=origin)
772
773
                if self.full_batch:
774
                    if nb == self.batch_size:
775
                        yield x_batch, y_batch, patients_ids
776
                else:
777
                    yield x_batch, y_batch, patients_ids
778
779
            if not self.infinite:
780
                break
781
782
class CandidatesLunaSizeValidDataGenerator(object):
783
    def __init__(self, data_path, transform_params, patient_ids, data_prep_fun, **kwargs):
784
        rng = np.random.RandomState(42)  # do not change this!!!
785
786
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
787
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
788
789
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
790
        self.id2positive_annotations = {}
791
        self.id2negative_annotations = {}
792
        self.id2patient_path = {}
793
        n_positive, n_negative = 0, 0
794
        for pid in patient_ids:
795
            if pid in id2positive_annotations:
796
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
797
                negative_annotations = id2negative_annotations[pid]
798
                n_pos = len(id2positive_annotations[pid])
799
                n_neg = len(id2negative_annotations[pid])
800
                neg_idxs = rng.choice(n_neg, size=n_pos, replace=False)
801
                negative_annotations_selected = []
802
                for i in neg_idxs:
803
                    negative_annotations_selected.append(negative_annotations[i])
804
                self.id2negative_annotations[pid] = negative_annotations_selected
805
806
                self.id2patient_path[pid] = data_path + '/' + pid + self.file_extension
807
                n_positive += n_pos
808
                n_negative += n_pos
809
810
        print 'n positive', n_positive
811
        print 'n negative', n_negative
812
813
        self.nsamples = len(self.id2patient_path)
814
        self.data_path = data_path
815
        self.rng = rng
816
        self.data_prep_fun = data_prep_fun
817
        self.transform_params = transform_params
818
819
    def generate(self):
820
821
        for pid in self.id2positive_annotations.iterkeys():
822
            for patch_center in self.id2positive_annotations[pid]:
823
                patient_path = self.id2patient_path[pid]
824
825
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
826
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
827
                y_batch = np.array([[float(patch_center[-1])]], dtype='float32')
828
                x_batch = np.float32(self.data_prep_fun(data=img,
829
                                                        patch_center=patch_center,
830
                                                        pixel_spacing=pixel_spacing,
831
                                                        luna_origin=origin))[None, None, :, :, :]
832
833
                yield x_batch, y_batch, [pid]
834
835
            for patch_center in self.id2negative_annotations[pid]:
836
                patient_path = self.id2patient_path[pid]
837
838
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
839
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
840
                y_batch = np.array([[0.]], dtype='float32')
841
                x_batch = np.float32(self.data_prep_fun(data=img,
842
                                                        patch_center=patch_center,
843
                                                        pixel_spacing=pixel_spacing,
844
                                                        luna_origin=origin))[None, None, :, :, :]
845
846
                yield x_batch, y_batch, [pid]
847
848
849
850
class CandidatesLunaSizeBinDataGenerator(object):
851
    def __init__(self, data_path, batch_size, transform_params, patient_ids, data_prep_fun, rng,
852
                 full_batch, random, infinite, positive_proportion, bin_borders = [4,8,20,50], **kwargs):
853
854
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
855
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
856
857
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
858
        self.id2positive_annotations = {}
859
        self.id2negative_annotations = {}
860
        self.patient_paths = []
861
        n_positive, n_negative = 0, 0
862
        for pid in patient_ids:
863
            if pid in id2positive_annotations:
864
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
865
                self.id2negative_annotations[pid] = id2negative_annotations[pid]
866
                self.patient_paths.append(data_path + '/' + pid + self.file_extension)
867
                n_positive += len(id2positive_annotations[pid])
868
                n_negative += len(id2negative_annotations[pid])
869
870
        print 'n positive', n_positive
871
        print 'n negative', n_negative
872
873
        self.nsamples = len(self.patient_paths)
874
875
        print 'n patients', self.nsamples
876
        self.data_path = data_path
877
        self.batch_size = batch_size
878
        self.rng = rng
879
        self.full_batch = full_batch
880
        self.random = random
881
        self.infinite = infinite
882
        self.data_prep_fun = data_prep_fun
883
        self.transform_params = transform_params
884
        self.positive_proportion = positive_proportion
885
        self.bin_borders = bin_borders
886
887
    def generate(self):
888
        while True:
889
            rand_idxs = np.arange(self.nsamples)
890
            if self.random:
891
                self.rng.shuffle(rand_idxs)
892
            for pos in xrange(0, len(rand_idxs), self.batch_size):
893
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
894
                nb = len(idxs_batch)
895
                # allocate batches
896
                x_batch = np.zeros((nb,) + self.transform_params['patch_size'], dtype='float32')
897
                y_batch = np.zeros((nb,), dtype='float32')
898
                patients_ids = []
899
900
                for i, idx in enumerate(idxs_batch):
901
                    patient_path = self.patient_paths[idx]
902
903
                    id = utils_lung.extract_pid_filename(patient_path, self.file_extension)
904
                    patients_ids.append(id)
905
906
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
907
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
908
                    if i < np.rint(self.batch_size * self.positive_proportion):
909
                        patient_annotations = self.id2positive_annotations[id]
910
                    else:
911
                        patient_annotations = self.id2negative_annotations[id]
912
913
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
914
915
                    diameter = patch_center[-1]
916
                    if diameter > 0.:
917
                        ybin = 0
918
                        for idx, border in enumerate(self.bin_borders):
919
                            if diameter<border:
920
                                ybin = idx
921
                                break                            
922
                        y_batch[i] = 1. + ybin
923
                    else:
924
                        y_batch[i] = 0. 
925
                    #print 'y_batch[i]', y_batch[i], 'diameter', diameter
926
927
                    x_batch[i, :, :, :] = self.data_prep_fun(data=img,
928
                                                                patch_center=patch_center,
929
                                                                pixel_spacing=pixel_spacing,
930
                                                                luna_origin=origin)
931
932
                if self.full_batch:
933
                    if nb == self.batch_size:
934
                        yield x_batch, y_batch, patients_ids
935
                else:
936
                    yield x_batch, y_batch, patients_ids
937
938
            if not self.infinite:
939
                break
940
941
class CandidatesLunaSizeBinValidDataGenerator(object):
942
    def __init__(self, data_path, transform_params, patient_ids, data_prep_fun, bin_borders = [4,8,20,50], **kwargs):
943
        rng = np.random.RandomState(42)  # do not change this!!!
944
945
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
946
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
947
948
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
949
        self.id2positive_annotations = {}
950
        self.id2negative_annotations = {}
951
        self.id2patient_path = {}
952
        n_positive, n_negative = 0, 0
953
        for pid in patient_ids:
954
            if pid in id2positive_annotations:
955
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
956
                negative_annotations = id2negative_annotations[pid]
957
                n_pos = len(id2positive_annotations[pid])
958
                n_neg = len(id2negative_annotations[pid])
959
                neg_idxs = rng.choice(n_neg, size=n_pos, replace=False)
960
                negative_annotations_selected = []
961
                for i in neg_idxs:
962
                    negative_annotations_selected.append(negative_annotations[i])
963
                self.id2negative_annotations[pid] = negative_annotations_selected
964
965
                self.id2patient_path[pid] = data_path + '/' + pid + self.file_extension
966
                n_positive += n_pos
967
                n_negative += n_pos
968
969
        print 'n positive', n_positive
970
        print 'n negative', n_negative
971
972
        self.nsamples = len(self.id2patient_path)
973
        self.data_path = data_path
974
        self.rng = rng
975
        self.data_prep_fun = data_prep_fun
976
        self.transform_params = transform_params
977
        self.bin_borders = bin_borders
978
979
    def generate(self):
980
981
        for pid in self.id2positive_annotations.iterkeys():
982
            for patch_center in self.id2positive_annotations[pid]:
983
                patient_path = self.id2patient_path[pid]
984
985
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
986
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
987
988
                diameter = patch_center[3]                        
989
                ybin = 0
990
                for idx, border in enumerate(self.bin_borders):
991
                    if diameter<border:
992
                        ybin = idx
993
                        break  
994
995
                y_batch = np.array([1. + ybin], dtype='float32')
996
                x_batch = np.float32(self.data_prep_fun(data=img,
997
                                                        patch_center=patch_center,
998
                                                        pixel_spacing=pixel_spacing,
999
                                                        luna_origin=origin))[None, :, :, :]
1000
1001
                yield x_batch, y_batch, [pid]
1002
1003
            for patch_center in self.id2negative_annotations[pid]:
1004
                patient_path = self.id2patient_path[pid]
1005
1006
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
1007
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
1008
                y_batch = np.array([0.], dtype='float32')
1009
                x_batch = np.float32(self.data_prep_fun(data=img,
1010
                                                        patch_center=patch_center,
1011
                                                        pixel_spacing=pixel_spacing,
1012
                                                        luna_origin=origin))[None, :, :, :]
1013
1014
                yield x_batch, y_batch, [pid]
1015
1016
1017
1018
class CandidatesLunaPropsDataGenerator(object):
1019
    def __init__(self, data_path, batch_size, transform_params, patient_ids, data_prep_fun, rng,
1020
                 full_batch, random, infinite, 
1021
                 positive_proportion,
1022
                 order_objectives,
1023
                 property_type,
1024
                 property_bin_borders = None,
1025
                 return_enable_target_vector = False, **kwargs):
1026
1027
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
1028
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
1029
1030
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
1031
        self.id2positive_annotations = {}
1032
        self.id2negative_annotations = {}
1033
        self.all_pids = patient_ids
1034
        self.pos_pids = []
1035
        self.neg_pids = []
1036
        n_positive, n_negative = 0, 0
1037
        for pid in patient_ids:
1038
            if pid in id2positive_annotations:
1039
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
1040
                self.id2negative_annotations[pid] = id2negative_annotations[pid]
1041
                self.pos_pids.append(pid)
1042
                n_positive += len(id2positive_annotations[pid])
1043
                n_negative += len(id2negative_annotations[pid])
1044
            elif pid in id2negative_annotations:
1045
                self.id2negative_annotations[pid] = id2negative_annotations[pid]
1046
                self.neg_pids.append(pid)
1047
                n_negative += len(id2negative_annotations[pid])
1048
            else:
1049
                print 'WARNING something weird happens'
1050
1051
        print 'n positive', n_positive
1052
        print 'n negative', n_negative
1053
1054
        self.n_neg_cans = n_negative
1055
        self.n_pos_cans = n_positive
1056
1057
        self.n_pos_pids = len(self.pos_pids)
1058
        self.n_neg_pids = len(self.neg_pids)
1059
1060
        self.nsamples = len(self.all_pids)
1061
        print 'n patients', self.nsamples
1062
        self.data_path = data_path
1063
        self.batch_size = batch_size
1064
        self.rng = rng
1065
        self.full_batch = full_batch
1066
        self.random = random
1067
        self.infinite = infinite
1068
        self.data_prep_fun = data_prep_fun
1069
        self.transform_params = transform_params
1070
        self.positive_proportion = positive_proportion
1071
1072
        self.order_objectives = order_objectives
1073
        self.property_bin_borders = property_bin_borders
1074
    self.property_type = property_type
1075
        #self.return_enable_target_vector = return_enable_target_vector
1076
1077
    def L2(self, a,b):
1078
        return ((a[0]-b[0])**2 + (a[1]-b[1])**2 + (a[2]-b[2])**2)**(0.5)
1079
1080
    def build_ground_truth_vector(self, pid, patch_center):
1081
        properties={}
1082
        feature_vector = np.zeros((len(self.order_objectives)), dtype='float32')
1083
        enable_target_vector = np.zeros((len(self.order_objectives)), dtype='float32')
1084
        diameter = patch_center[-1]
1085
        is_nodule  = diameter>0.01
1086
        properties['nodule'] = np.float32(is_nodule)
1087
        if is_nodule:
1088
            if 'size' in self.property_bin_borders:
1089
                properties['size'] = np.digitize(diameter, self.property_bin_borders['size'])
1090
            else:
1091
                properties['size'] = diameter
1092
            
1093
            patient = utils_lung.read_patient_annotations_luna(pid, pathfinder.LUNA_NODULE_ANNOTATIONS_PATH)
1094
1095
            #find the nodules in the doctor's annotations
1096
            nodule_characteristics = []
1097
            for doctor in patient:
1098
                for nodule in doctor:
1099
                    if "centroid_xyz" in nodule:
1100
                        dist = self.L2(patch_center[:3],nodule["centroid_xyz"][::-1])
1101
                        if  dist < 5:
1102
                            #print 'found a very close nodule at', dist, ': ', patch_center[:3]
1103
                            nodule_characteristics.append(nodule['characteristics'])
1104
1105
            if len(nodule_characteristics)==0:
1106
                print 'WARNING: no nodule found in doctor annotations for ', patch_center
1107
            else:
1108
                #calculate the median property values
1109
                for prop in nodule_characteristics[0]:
1110
                    if prop in self.order_objectives:
1111
                        prop_values = []
1112
                        for nchar in nodule_characteristics:
1113
                            prop_values.append(float(nchar[prop]))
1114
                            random_value = self.rng.choice(np.array(prop_values))
1115
                            if prop in self.property_bin_borders:
1116
                                properties[prop] = np.digitize(random_value, self.property_bin_borders[prop])
1117
                            else:      
1118
                                if self.property_type:
1119
                                    if self.property_type[prop] == 'bounded_continuous':
1120
                                        properties[prop] = (random_value-1) / 4.
1121
                                    else:
1122
                                        properties[prop] = random_value-1
1123
                                else:
1124
                                    raise
1125
1126
        for idx, prop in enumerate(self.order_objectives):
1127
            if prop in properties:
1128
                feature_vector[idx] = properties[prop]
1129
                enable_target_vector[idx] = 1.
1130
            
1131
        return feature_vector, enable_target_vector
1132
1133
    def generate(self):
1134
        while True:
1135
            # Construct pid set with
1136
            rand_pos_idxs = np.arange(self.n_pos_pids)
1137
            rand_neg_idxs = np.arange(self.n_neg_pids)
1138
            ptr_pos_idcs = 0
1139
            ptr_neg_idcs = 0
1140
1141
            if self.random:
1142
                self.rng.shuffle(rand_pos_idxs)
1143
                self.rng.shuffle(rand_neg_idxs)
1144
1145
            n_pos_batch = int(np.rint(self.batch_size * self.positive_proportion))
1146
            n_neg_batch = self.batch_size - n_pos_batch
1147
            for _idx, pos_pos in enumerate(xrange(0, len(rand_pos_idxs), n_pos_batch)):
1148
                pos_idxs_batch = rand_pos_idxs[pos_pos:pos_pos + n_pos_batch]
1149
                neg_idxs_batch = rand_neg_idxs[_idx * n_neg_batch:(_idx+1) * n_neg_batch]
1150
1151
                nb = len(pos_idxs_batch) + len(neg_idxs_batch)
1152
                # allocate batches
1153
                x_batch = np.zeros((nb,) + self.transform_params['patch_size'], dtype='float32')
1154
                y_batch = np.zeros((nb, len(self.order_objectives)), dtype='float32')
1155
                z_batch = np.zeros((nb, len(self.order_objectives)), dtype='float32')
1156
                patients_ids = []
1157
1158
                batch_ptr = 0
1159
                for idx in pos_idxs_batch:
1160
                    pid  = self.pos_pids[idx]
1161
                    patient_path = self.data_path + '/' + pid + self.file_extension
1162
                    patients_ids.append(pid)
1163
1164
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
1165
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
1166
1167
                    patient_annotations = self.id2positive_annotations[pid]
1168
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
1169
1170
                    y_batch[batch_ptr], z_batch[batch_ptr] = self.build_ground_truth_vector(pid, patch_center)
1171
                    x_batch[batch_ptr, :, :, :] = self.data_prep_fun(data=img,
1172
                                                                patch_center=patch_center,
1173
                                                                pixel_spacing=pixel_spacing,
1174
                                                                luna_origin=origin)
1175
                    batch_ptr += 1
1176
1177
                for idx in neg_idxs_batch:
1178
                    pid  = self.neg_pids[idx]
1179
                    patient_path = self.data_path + '/' + pid + self.file_extension
1180
                    patients_ids.append(pid)
1181
1182
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
1183
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
1184
1185
                    patient_annotations = self.id2negative_annotations[pid]
1186
                    patch_center = patient_annotations[self.rng.randint(len(patient_annotations))]
1187
1188
                    y_batch[batch_ptr], z_batch[batch_ptr] = self.build_ground_truth_vector(pid, patch_center)
1189
                    x_batch[batch_ptr, :, :, :] = self.data_prep_fun(data=img,
1190
                                                                patch_center=patch_center,
1191
                                                                pixel_spacing=pixel_spacing,
1192
                                                                luna_origin=origin)
1193
                    batch_ptr += 1
1194
1195
1196
                if self.full_batch:
1197
                    if nb == self.batch_size:
1198
                        yield x_batch, y_batch, z_batch, patients_ids
1199
                else:
1200
                    yield x_batch, y_batch, z_batch, patients_ids
1201
1202
            if not self.infinite:
1203
                break
1204
1205
1206
class CandidatesLunaPropsValidDataGenerator(object):
1207
    def __init__(self, data_path, transform_params, patient_ids, data_prep_fun, 
1208
                    order_objectives, property_type, property_bin_borders=None, **kwargs):
1209
        rng = np.random.RandomState(42)  # do not change this!!!
1210
1211
        id2positive_annotations = utils_lung.read_luna_annotations(pathfinder.LUNA_LABELS_PATH)
1212
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)
1213
1214
        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
1215
        self.id2positive_annotations = {}
1216
        self.id2negative_annotations = {}
1217
        self.id2patient_path = {}
1218
        n_positive, n_negative = 0, 0
1219
        for pid in patient_ids:
1220
            if pid in id2positive_annotations:
1221
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
1222
                negative_annotations = id2negative_annotations[pid]
1223
                n_pos = len(id2positive_annotations[pid])
1224
                n_neg = len(id2negative_annotations[pid])
1225
                neg_idxs = rng.choice(n_neg, size=n_pos, replace=False)
1226
                negative_annotations_selected = []
1227
                for i in neg_idxs:
1228
                    negative_annotations_selected.append(negative_annotations[i])
1229
                self.id2negative_annotations[pid] = negative_annotations_selected
1230
1231
                self.id2patient_path[pid] = data_path + '/' + pid + self.file_extension
1232
                n_positive += n_pos
1233
                n_negative += n_pos
1234
1235
        print 'n positive', n_positive
1236
        print 'n negative', n_negative
1237
1238
        self.nsamples = len(self.id2patient_path)
1239
        self.data_path = data_path
1240
        self.rng = rng
1241
        self.data_prep_fun = data_prep_fun
1242
        self.transform_params = transform_params
1243
1244
        self.order_objectives = order_objectives
1245
        self.property_bin_borders = property_bin_borders
1246
        self.property_type = property_type
1247
    
1248
1249
    def L2(self, a,b):
1250
        return ((a[0]-b[0])**2 + (a[1]-b[1])**2 + (a[2]-b[2])**2)**(0.5)
1251
1252
1253
    def build_ground_truth_vector(self, pid, patch_center):
1254
        properties={}
1255
        feature_vector = np.zeros((len(self.order_objectives)), dtype='float32')
1256
        enable_target_vector = np.zeros((len(self.order_objectives)), dtype='float32')
1257
        diameter = patch_center[-1]
1258
        is_nodule  = diameter>0.01
1259
        properties['nodule'] = np.float32(is_nodule)
1260
        if is_nodule:
1261
            if 'size' in self.property_bin_borders:
1262
                properties['size'] = np.digitize(diameter, self.property_bin_borders['size'])
1263
            else:
1264
                properties['size'] = diameter
1265
            
1266
            patient = utils_lung.read_patient_annotations_luna(pid, pathfinder.LUNA_NODULE_ANNOTATIONS_PATH)
1267
1268
            #find the nodules in the doctor's annotations
1269
            nodule_characteristics = []
1270
            for doctor in patient:
1271
                for nodule in doctor:
1272
                    if "centroid_xyz" in nodule:
1273
                        dist = self.L2(patch_center[:3],nodule["centroid_xyz"][::-1])
1274
                        if  dist < 5:
1275
                            #print 'found a very close nodule at', dist, ': ', patch_center[:3]
1276
                            nodule_characteristics.append(nodule['characteristics'])
1277
1278
            if len(nodule_characteristics)==0:
1279
                print 'WARNING: no nodule found in doctor annotations for ', patch_center
1280
            else:
1281
                #calculate the median property values
1282
                for prop in nodule_characteristics[0]:
1283
                    if prop in self.order_objectives:
1284
                        prop_values = []
1285
                        for nchar in nodule_characteristics:
1286
                            prop_values.append(float(nchar[prop]))
1287
                        if prop in self.property_bin_borders:
1288
                            median_value = np.median(np.array(prop_values))
1289
                            properties[prop] = np.digitize(median_value, self.property_bin_borders[prop])
1290
                        else:
1291
                            mean_value = np.mean(np.array(prop_values))
1292
                            if self.property_type:
1293
                                if self.property_type[prop] == 'bounded_continuous':
1294
                                    properties[prop] = (mean_value-1) / 4.
1295
                                else:
1296
                                    properties[prop] = mean_value-1
1297
                            else:
1298
                                raise
1299
1300
        for idx, prop in enumerate(self.order_objectives):
1301
            if prop in properties:
1302
                feature_vector[idx] = properties[prop]
1303
                enable_target_vector[idx] = 1.
1304
            
1305
        return feature_vector, enable_target_vector
1306
1307
1308
    def generate(self):
1309
1310
        for pid in self.id2positive_annotations.iterkeys():
1311
            for patch_center in self.id2positive_annotations[pid]:
1312
                patient_path = self.id2patient_path[pid]
1313
1314
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
1315
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
1316
1317
                x_batch = np.float32(self.data_prep_fun(data=img,
1318
                                                        patch_center=patch_center,
1319
                                                        pixel_spacing=pixel_spacing,
1320
                                                        luna_origin=origin))[None, :, :, :]
1321
1322
                feature_vector, enable_target_vector = self.build_ground_truth_vector(pid, patch_center)
1323
                y_batch = np.array([feature_vector], dtype='float32')
1324
                z_batch = np.array([enable_target_vector], dtype='float32')
1325
1326
                yield x_batch, y_batch, z_batch, [pid]
1327
1328
            for patch_center in self.id2negative_annotations[pid]:
1329
                patient_path = self.id2patient_path[pid]
1330
1331
                img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
1332
                    if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
1333
1334
                x_batch = np.float32(self.data_prep_fun(data=img,
1335
                                                        patch_center=patch_center,
1336
                                                        pixel_spacing=pixel_spacing,
1337
                                                        luna_origin=origin))[None, :, :, :]
1338
1339
                feature_vector, enable_target_vector = self.build_ground_truth_vector(pid, patch_center)
1340
                y_batch = np.array([feature_vector], dtype='float32')
1341
                z_batch = np.array([enable_target_vector], dtype='float32')
1342
1343
                yield x_batch, y_batch, z_batch, [pid]
1344
1345
1346
class DSBScanDataGenerator(object):
    """Iterate once over every DSB scan found under ``data_path``.

    Each yielded item is ``(x, None, tf_matrix, pid)``: ``x`` is the float32
    output of ``data_prep_fun`` with two leading singleton axes prepended, and
    ``tf_matrix`` is the second value ``data_prep_fun`` returns.
    """

    def __init__(self, data_path, transform_params, data_prep_fun, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.patient_paths = utils_lung.get_patient_data_paths(data_path)
        self.nsamples = len(self.patient_paths)

    def generate(self):
        """Yield one preprocessed scan per patient path (a single pass)."""
        for scan_dir in self.patient_paths:
            patient_id = utils_lung.extract_pid_dir(scan_dir)
            volume, spacing = utils_lung.read_dicom_scan(scan_dir)
            prepped, transform = self.data_prep_fun(data=volume, pixel_spacing=spacing)
            # Prepend two singleton axes in front of the three spatial slices.
            batch = np.float32(prepped)[None, None, :, :, :]
            yield batch, None, transform, patient_id
1364
1365
1366
class DSBScanLungMaskDataGenerator(object):
    """Iterate once over DSB scans, yielding each scan with its lung mask.

    Args:
        data_path: directory listed with ``utils_lung.get_patient_data_paths``.
        transform_params: stored on the instance; not read by this class.
        data_prep_fun: callable(data, pixel_spacing) -> (x, lung_mask, tf_matrix).
        exclude_pids: optional iterable; for each entry, the first remaining
            path that contains it as a substring is dropped.
        include_pids: optional iterable of pids; when given it replaces the
            path list entirely (sharding and exclusion results are discarded).
        part_out_of: ``(k, n)`` keeps only the k-th (1-based) of n equal slices
            of the listed paths; the last slice also receives the remainder.
    """

    def __init__(self, data_path, transform_params, data_prep_fun, exclude_pids=None,
                 include_pids=None, part_out_of=(1, 1)):
        paths = utils_lung.get_patient_data_paths(data_path)

        # Shard selection: slice [size*(k-1), size*k), open-ended for the last shard.
        part, n_parts = part_out_of[0], part_out_of[1]
        part_size = int(len(paths) / n_parts)
        start = part_size * (part - 1)
        stop = None if part == n_parts else part_size * part
        paths = paths[start:stop]

        # Drop only the first path matching each excluded pid.
        if exclude_pids is not None:
            for excluded in exclude_pids:
                hit = next((j for j, path in enumerate(paths) if excluded in path), None)
                if hit is not None:
                    del paths[hit]

        if include_pids is not None:
            paths = [data_path + '/' + p for p in include_pids]

        self.patient_paths = paths
        self.nsamples = len(paths)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params

    def generate(self):
        """Yield ``(x, lung_mask, tf_matrix, pid)`` per patient (a single pass).

        ``x`` and ``lung_mask`` are converted to float32 and get two leading
        singleton axes.
        """
        for scan_dir in self.patient_paths:
            patient_id = utils_lung.extract_pid_dir(scan_dir)
            volume, spacing = utils_lung.read_dicom_scan(scan_dir)
            prepped, mask, transform = self.data_prep_fun(data=volume, pixel_spacing=spacing)
            yield (np.float32(prepped)[None, None, :, :, :],
                   np.float32(mask)[None, None, :, :, :],
                   transform,
                   patient_id)
1407
1408
1409
class CandidatesDSBDataGenerator(object):
    """Yield one preprocessed patch per candidate for every DSB patient.

    Args:
        data_path: directory containing one scan folder per patient id.
        transform_params: stored on the instance; not read by this class.
        id2candidates_path: mapping pid -> path of a pickled candidate array.
            Only pids in this mapping are visited. The mapping is copied, so
            excluding pids never modifies the caller's dict.
        data_prep_fun: callable(data, patch_center, pixel_spacing) -> patch.
        exclude_pids: optional iterable of pids to skip.
    """

    def __init__(self, data_path, transform_params, id2candidates_path, data_prep_fun, exclude_pids=None):
        # Work on a copy: popping from the caller's dict would silently remove
        # these pids from every other consumer sharing that mapping.
        id2candidates_path = dict(id2candidates_path)
        if exclude_pids is not None:
            for p in exclude_pids:
                id2candidates_path.pop(p, None)

        self.id2candidates_path = id2candidates_path
        self.id2patient_path = dict((pid, data_path + '/' + pid) for pid in id2candidates_path)

        self.nsamples = len(self.id2patient_path)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params

    def generate(self):
        """Yield ``(x_batch, y_batch, [pid])`` for each candidate of each patient.

        ``x_batch`` is the float32 patch with a leading singleton axis;
        ``y_batch`` is the whole candidate row as float32, whose first three
        entries are passed to ``data_prep_fun`` as the patch centre.
        """
        for pid in self.id2candidates_path:
            patient_path = self.id2patient_path[pid]
            print('%s %s' % (pid, patient_path))
            img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)

            print(self.id2candidates_path[pid])
            candidates = utils.load_pkl(self.id2candidates_path[pid])
            print(candidates.shape)
            for candidate in candidates:
                y_batch = np.array(candidate, dtype='float32')
                patch_center = candidate[:3]
                x_batch = np.float32(self.data_prep_fun(data=img,
                                                        patch_center=patch_center,
                                                        pixel_spacing=pixel_spacing))[None, :, :, :]
                yield x_batch, y_batch, [pid]
1443
1444
1445
1446
1447
class CandidatesDSBDataGeneratorTTA(object):
    """Like CandidatesDSBDataGenerator, but each candidate is prepared ``tta`` times.

    Args:
        data_path: directory containing one scan folder per patient id.
        transform_params: stored on the instance; not read by this class.
        id2candidates_path: mapping pid -> path of a pickled candidate array.
            It is copied, so excluding pids never modifies the caller's dict.
        data_prep_fun: callable(data, patch_center, pixel_spacing) -> patch.
            It is called ``tta`` times per candidate; presumably it applies a
            random augmentation (TODO confirm), otherwise the copies are identical.
        exclude_pids: optional iterable of pids to skip.
        tta: number of preprocessed copies stacked per candidate.
    """

    def __init__(self, data_path, transform_params, id2candidates_path, data_prep_fun, exclude_pids=None, tta=64):
        # Work on a copy: popping from the caller's dict would silently remove
        # these pids from every other consumer sharing that mapping.
        id2candidates_path = dict(id2candidates_path)
        if exclude_pids is not None:
            for p in exclude_pids:
                id2candidates_path.pop(p, None)

        self.id2candidates_path = id2candidates_path
        self.id2patient_path = dict((pid, data_path + '/' + pid) for pid in id2candidates_path)

        self.nsamples = len(self.id2patient_path)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.tta = tta

    def generate(self):
        """Yield ``(x_batch, y_batch, [pid])`` per candidate.

        ``x_batch`` stacks ``tta`` float32 preprocessed patches along a new
        leading axis; ``y_batch`` is the candidate row as float32.
        """
        for pid in self.id2candidates_path:
            patient_path = self.id2patient_path[pid]
            print('%s %s' % (pid, patient_path))
            img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)

            print(self.id2candidates_path[pid])
            candidates = utils.load_pkl(self.id2candidates_path[pid])
            print(candidates.shape)
            for candidate in candidates:
                y_batch = np.array(candidate, dtype='float32')
                patch_center = candidate[:3]
                x_batch = np.stack([np.float32(self.data_prep_fun(data=img,
                                                                  patch_center=patch_center,
                                                                  pixel_spacing=pixel_spacing))
                                    for _ in range(self.tta)])
                print(x_batch.shape)

                yield x_batch, y_batch, [pid]
1486
1487
1488
class DSBFeatureDataGenerator(object):
    """Batch precomputed per-patient feature pickles with their DSB labels.

    Features for a patient are read from ``<data_path>/<pid>.pkl``, optionally
    reshaped (``p_features['reshape']``) and axis-swapped
    (``p_features['swapaxes']``), then stored into a float32 array of shape
    ``(batch_size,) + p_features['output_shape']``. Labels come from
    ``id2label.get(pid)``, which returns None for unknown pids.

    Only full batches are yielded; the trailing partial batch is skipped
    without loading its files.
    """

    def __init__(self, data_path, batch_size, p_features,
                 rng, random, infinite, patient_ids=None):

        print('init DSBFeatureDataGenerator')

        self.id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
        if patient_ids is None:
            raise ValueError('provide patient ids')
        self.patient_paths = [data_path + '/' + pid for pid in patient_ids]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.batch_size = batch_size

        self.p_features = p_features
        self.rng = rng
        self.random = random
        self.infinite = infinite

    def generate(self):
        """Yield ``(x_batch, y_batch, pids_batch)`` full batches.

        Raises:
            ValueError: if ``infinite`` is set but there are fewer samples than
                one batch. Such a generator could never yield and would
                otherwise spin forever.
        """
        if self.infinite and self.nsamples < self.batch_size:
            raise ValueError('infinite DSBFeatureDataGenerator needs at least batch_size (%s) samples, got %s'
                             % (self.batch_size, self.nsamples))
        while True:
            rand_idxs = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(rand_idxs)

            # Partial batches are never yielded, so stop before the trailing
            # one instead of loading its pickles and discarding them.
            n_full = len(rand_idxs) - len(rand_idxs) % self.batch_size
            for pos in range(0, n_full, self.batch_size):
                idxs_batch = rand_idxs[pos:pos + self.batch_size]

                x_batch = np.zeros((self.batch_size,)
                                   + self.p_features['output_shape'], dtype='float32')
                y_batch = np.zeros((self.batch_size,), dtype='float32')
                pids_batch = []

                for i, idx in enumerate(idxs_batch):
                    patient_path = self.patient_paths[idx]
                    pid = utils_lung.extract_pid_dir(patient_path)

                    t_features = utils.load_pkl(patient_path + '.pkl')
                    if 'reshape' in self.p_features:
                        t_features = np.reshape(t_features, self.p_features['reshape'])
                    if 'swapaxes' in self.p_features:
                        t_features = np.swapaxes(t_features, *self.p_features['swapaxes'])

                    x_batch[i] = t_features
                    y_batch[i] = self.id2label.get(pid)
                    pids_batch.append(pid)

                yield x_batch, y_batch, pids_batch

            if not self.infinite:
                break
1549
1550
class DSBPatientsDataGenerator(object):
    """Batch candidate patches (optionally with their locations) per DSB patient.

    For each patient the pickled candidates are loaded and reduced either by
    ``candidates_prep_fun(all_candidates, n_candidates_per_patient)`` or, when
    that is falsy, by taking the first ``n_candidates_per_patient`` rows
    (shuffled in place when ``shuffle_top_n``). ``data_prep_fun`` turns them
    into one row of a float32 array of shape
    ``(batch_size, n_candidates_per_patient) + transform_params['patch_size']``.
    Only full batches are yielded.
    """

    def __init__(self, data_path, batch_size, transform_params, id2candidates_path, id2label, data_prep_fun,
                 n_candidates_per_patient, rng, random, infinite, candidates_prep_fun, return_patch_locs=False, shuffle_top_n=False, patient_ids=None):
        self.id2label = id2label
        self.id2candidates_path = id2candidates_path
        if patient_ids is None:
            raise ValueError('provide patient ids')
        # TODO: this filter should be redundant if fpr and segmentation are correctly generated
        self.patient_paths = [data_path + '/' + pid for pid in patient_ids
                              if pid in id2candidates_path]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.batch_size = batch_size
        self.transform_params = transform_params
        self.n_candidates_per_patient = n_candidates_per_patient
        self.rng = rng
        self.random = random
        self.infinite = infinite
        self.shuffle_top_n = shuffle_top_n
        self.return_patch_locs = return_patch_locs
        self.candidates_prep_fun = candidates_prep_fun

    def _pick_candidates(self, all_candidates):
        """Return the candidate rows whose patches are extracted for one patient."""
        n_keep = self.n_candidates_per_patient
        if self.candidates_prep_fun:
            return self.candidates_prep_fun(all_candidates, n_keep)
        chosen = all_candidates[:n_keep]
        if self.shuffle_top_n:
            self.rng.shuffle(chosen)
        return chosen

    def generate(self):
        """Yield ``(x, [x_loc,] y, pids)`` full batches; x_loc only with return_patch_locs."""
        keep_going = True
        while keep_going:
            order = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(order)

            for start in range(0, len(order), self.batch_size):
                members = order[start:start + self.batch_size]

                x_batch = np.zeros((self.batch_size, self.n_candidates_per_patient)
                                   + self.transform_params['patch_size'], dtype='float32')
                if self.return_patch_locs:
                    x_loc_batch = np.zeros((self.batch_size, self.n_candidates_per_patient, 3), dtype='float32')
                y_batch = np.zeros((self.batch_size,), dtype='float32')
                pids_batch = []

                for slot, idx in enumerate(members):
                    patient_path = self.patient_paths[idx]
                    pid = utils_lung.extract_pid_dir(patient_path)
                    img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)

                    top_candidates = self._pick_candidates(utils.load_pkl(self.id2candidates_path[pid]))

                    if self.return_patch_locs:
                        # TODO move the normalization to the config file
                        x_loc_batch[slot] = np.float32(top_candidates[:, :3]) / 512.

                    patches = self.data_prep_fun(data=img, pid=pid,
                                                 patch_centers=top_candidates,
                                                 pixel_spacing=pixel_spacing)
                    x_batch[slot] = np.float32(patches)[:, :, :, :]
                    y_batch[slot] = self.id2label.get(pid)
                    pids_batch.append(pid)

                if len(members) != self.batch_size:
                    continue
                if self.return_patch_locs:
                    yield x_batch, x_loc_batch, y_batch, pids_batch
                else:
                    yield x_batch, y_batch, pids_batch

            keep_going = self.infinite
1627
1628
1629
1630
class DSBPatientsDataGeneratorTTA(object):
    """Yield test-time-augmentation batches, one patient at a time.

    For every patient, ``data_prep_fun`` is called ``tta`` times on the same
    selected candidates. Each output fills one row of a float32 array of shape
    ``(tta, n_candidates_per_patient) + transform_params['patch_size']``,
    yielded with a length-``tta`` label vector and the pid.
    """

    def __init__(self, data_path, transform_params, id2candidates_path, id2label, data_prep_fun, candidates_prep_fun,
                 n_candidates_per_patient, patient_ids, tta=1):
        self.id2label = id2label
        self.id2candidates_path = id2candidates_path
        if patient_ids is None:
            raise ValueError('provide patient ids')
        # TODO: this filter should be redundant if fpr and segmentation are correctly generated
        self.patient_paths = [data_path + '/' + pid for pid in patient_ids
                              if pid in id2candidates_path]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.n_candidates_per_patient = n_candidates_per_patient
        self.tta = tta
        self.candidates_prep_fun = candidates_prep_fun

    def generate(self):
        """Yield ``(x_batch, y_batch, pid)`` for each patient (a single pass)."""
        print('')
        for idx in range(self.nsamples):
            n_aug = self.tta
            x_batch = np.zeros((n_aug, self.n_candidates_per_patient)
                               + self.transform_params['patch_size'], dtype='float32')
            y_batch = np.zeros((n_aug,), dtype='float32')

            patient_path = self.patient_paths[idx]
            pid = utils_lung.extract_pid_dir(patient_path)
            img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)

            all_candidates = utils.load_pkl(self.id2candidates_path[pid])
            top_candidates = (self.candidates_prep_fun(all_candidates, self.n_candidates_per_patient)
                              if self.candidates_prep_fun
                              else all_candidates[:self.n_candidates_per_patient])

            for aug in range(n_aug):
                prepped = self.data_prep_fun(data=img,
                                             patch_centers=top_candidates,
                                             pixel_spacing=pixel_spacing)
                x_batch[aug] = np.float32(prepped)[:, :, :, :]
                y_batch[aug] = self.id2label.get(pid)

            yield x_batch, y_batch, pid
1679
1680
1681
1682
1683
class DSBPixelSpacingsGenerator(object):
    """Yield ``(pid, pixel_spacing)`` for every patient that has candidates.

    The full DICOM scan is read for each patient; only its spacing is yielded.
    """

    def __init__(self, data_path, id2candidates_path, patient_ids):
        self.id2candidates_path = id2candidates_path
        if patient_ids is None:
            raise ValueError('provide patient ids')
        # TODO: this filter should be redundant if fpr and segmentation are correctly generated
        self.patient_paths = [data_path + '/' + pid for pid in patient_ids
                              if pid in id2candidates_path]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path

    def generate(self):
        """Yield one ``(pid, pixel_spacing)`` pair per patient (a single pass)."""
        for patient_path in (self.patient_paths[k] for k in range(self.nsamples)):
            pid = utils_lung.extract_pid_dir(patient_path)
            _, spacing = utils_lung.read_dicom_scan(patient_path)
            yield pid, spacing
1708
1709
1710
class DSBPatientsDataGenerator_only_heatmap(object):
    """Batch one per-patient array built from the candidates, with DSB labels.

    ``candidates_prep_fun(all_candidates)`` converts the pickled candidates into
    the ``candidates`` argument of ``data_prep_fun(data, candidates,
    pixel_spacing)``. Its float32 output fills one row of an array of shape
    ``(batch_size,) + transform_params['heatmap_size']``. Only full batches are
    yielded.

    ``return_patch_locs`` is accepted but not stored; ``shuffle_top_n`` and
    ``n_candidates_per_patient`` are stored but not used by this class.
    """

    def __init__(self, data_path, batch_size, transform_params, id2candidates_path, data_prep_fun,
                 n_candidates_per_patient, rng, random, infinite, candidates_prep_fun, return_patch_locs=False, shuffle_top_n=False, patient_ids=None):
        self.id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
        self.id2candidates_path = id2candidates_path
        if patient_ids is None:
            raise ValueError('provide patient ids')
        # TODO: this filter should be redundant if fpr and segmentation are correctly generated
        self.patient_paths = [data_path + '/' + pid for pid in patient_ids
                              if pid in id2candidates_path]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.batch_size = batch_size
        self.transform_params = transform_params
        self.rng = rng
        self.random = random
        self.infinite = infinite
        self.shuffle_top_n = shuffle_top_n
        self.candidates_prep_fun = candidates_prep_fun
        self.n_candidates_per_patient = n_candidates_per_patient

    def generate(self):
        """Yield ``(x_batch, y_batch, pids_batch)`` full batches."""
        keep_going = True
        while keep_going:
            order = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(order)

            for start in range(0, len(order), self.batch_size):
                members = order[start:start + self.batch_size]

                x_batch = np.zeros((self.batch_size,)
                                   + self.transform_params['heatmap_size'], dtype='float32')
                y_batch = np.zeros((self.batch_size,), dtype='float32')
                pids_batch = []

                for slot, idx in enumerate(members):
                    patient_path = self.patient_paths[idx]
                    pid = utils_lung.extract_pid_dir(patient_path)
                    img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)

                    scored = self.candidates_prep_fun(utils.load_pkl(self.id2candidates_path[pid]))
                    x_batch[slot] = np.float32(self.data_prep_fun(data=img,
                                                                  candidates=scored,
                                                                  pixel_spacing=pixel_spacing))
                    y_batch[slot] = self.id2label.get(pid)
                    pids_batch.append(pid)

                if len(members) == self.batch_size:
                    yield x_batch, y_batch, pids_batch

            keep_going = self.infinite
1771
1772
1773
class DSBPatientsDataGeneratorRandomSelectionNonCancerous(object):
    """Batch candidate patches per patient, randomising the selection for negatives.

    Patients with a truthy label use their first ``n_candidates_per_patient``
    candidates. The others get a random subset of that size drawn from their
    first ``top_false`` candidates (fewer if fewer exist). ``top_true`` is
    stored but not used by this class. Only full batches are yielded.
    """

    def __init__(self, data_path, batch_size, transform_params, id2candidates_path, data_prep_fun,
                 n_candidates_per_patient, rng, random, infinite, top_true=10, top_false=16, shuffle_top_n=False, patient_ids=None):

        self.id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
        self.id2candidates_path = id2candidates_path
        if patient_ids is None:
            raise ValueError('provide patient ids')
        # TODO: this filter should be redundant if fpr and segmentation are correctly generated
        self.patient_paths = [data_path + '/' + pid for pid in patient_ids
                              if pid in id2candidates_path]

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.batch_size = batch_size
        self.transform_params = transform_params
        self.n_candidates_per_patient = n_candidates_per_patient
        self.rng = rng
        self.random = random
        self.infinite = infinite
        self.shuffle_top_n = shuffle_top_n
        self.top_true = top_true
        self.top_false = top_false

    def _select_candidates(self, all_candidates, label):
        """Return the candidate rows to extract patches around for one patient.

        Args:
            all_candidates: array of candidate rows, indexable by an int array.
            label: the patient's label; truthy means cancerous (None counts as not).
        """
        if label:
            top_candidates = all_candidates[:self.n_candidates_per_patient]
        else:
            # Draw only from rows that exist. With fewer than top_false
            # candidates the unclipped pool could yield out-of-range indices.
            pool = np.arange(min(self.top_false, len(all_candidates)))
            self.rng.shuffle(pool)
            top_candidates = all_candidates[pool[:self.n_candidates_per_patient]]

        if self.shuffle_top_n:
            self.rng.shuffle(top_candidates)
        return top_candidates

    def generate(self):
        """Yield ``(x_batch, y_batch, pids_batch)`` full batches.

        ``x_batch`` has shape ``(batch_size, n_candidates_per_patient, 1) +
        transform_params['patch_size']``.
        """
        while True:
            rand_idxs = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(rand_idxs)

            for pos in range(0, len(rand_idxs), self.batch_size):
                idxs_batch = rand_idxs[pos:pos + self.batch_size]

                x_batch = np.zeros((self.batch_size, self.n_candidates_per_patient, 1)
                                   + self.transform_params['patch_size'], dtype='float32')
                y_batch = np.zeros((self.batch_size,), dtype='float32')
                pids_batch = []

                for i, idx in enumerate(idxs_batch):
                    patient_path = self.patient_paths[idx]
                    pid = utils_lung.extract_pid_dir(patient_path)

                    img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)
                    all_candidates = utils.load_pkl(self.id2candidates_path[pid])

                    label = self.id2label.get(pid)
                    top_candidates = self._select_candidates(all_candidates, label)

                    # Insert a singleton channel axis after the candidate axis.
                    x_batch[i] = np.float32(self.data_prep_fun(data=img,
                                                               patch_centers=top_candidates,
                                                               pixel_spacing=pixel_spacing))[:, None, :, :, :]
                    y_batch[i] = label
                    pids_batch.append(pid)

                if len(idxs_batch) == self.batch_size:
                    yield x_batch, y_batch, pids_batch

            if not self.infinite:
                break
1845
1846
#balance between patients with and without cancer
1847
class BalancedDSBPatientsDataGenerator(object):
    """Batch DSB patients so cancerous and non-cancerous ones appear about equally.

    Each batch slot is filled by a coin flip (``rng.randint(2)``). Heads takes
    a cancerous patient drawn with replacement; tails takes the next
    non-cancerous patient in this epoch's order (shuffled when ``random``). An
    epoch ends once every non-cancerous patient has been used, and a trailing
    partial batch is yielded too.

    Only patients present in ``id2candidates_path`` are used, because
    ``prepare_batch`` loads their candidate file.
    """

    def __init__(self, data_path, batch_size, transform_params, id2candidates_path, data_prep_fun,
                 n_candidates_per_patient, rng, random, infinite, shuffle_top_n=False, patient_ids=None):

        self.id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
        self.id2candidates_path = id2candidates_path
        if patient_ids is None:
            raise ValueError('provide patient ids')

        usable_pids, self.pos_ids, self.neg_ids = self._split_by_label(patient_ids)
        self.patient_paths = [data_path + '/' + pid for pid in usable_pids]
        self.n_pos_ids = len(self.pos_ids)
        self.n_neg_ids = len(self.neg_ids)
        print('n positive ids %s' % (self.n_pos_ids,))
        print('n negative ids %s' % (self.n_neg_ids,))
        self.all_pids = usable_pids
        self.nsamples = len(self.all_pids)

        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.batch_size = batch_size
        self.transform_params = transform_params
        self.n_candidates_per_patient = n_candidates_per_patient
        self.rng = rng
        self.random = random
        self.infinite = infinite
        self.shuffle_top_n = shuffle_top_n

    def _split_by_label(self, patient_ids):
        """Return ``(usable, positive, negative)`` pid lists.

        A pid is usable only if it has an entry in ``id2candidates_path``. Other
        pids would raise KeyError in ``prepare_batch``; this is the same filter
        the other DSB generators apply. Usable pids are partitioned by the
        truthiness of ``id2label[pid]``, which raises KeyError for an unlabeled
        usable pid.
        """
        usable, positive, negative = [], [], []
        for pid in patient_ids:
            if pid not in self.id2candidates_path:
                continue
            usable.append(pid)
            if self.id2label[pid]:
                positive.append(pid)
            else:
                negative.append(pid)
        return usable, positive, negative

    def generate(self):
        """Yield ``prepare_batch`` results for each (possibly partial) batch."""
        while True:
            neg_rand_idxs = np.arange(self.n_neg_ids)
            if self.random:
                self.rng.shuffle(neg_rand_idxs)
            neg_rand_idxs_ptr = 0
            batch_pids = []
            while neg_rand_idxs_ptr < self.n_neg_ids:
                if self.rng.randint(2):
                    # take a cancerous patient (sampled with replacement)
                    batch_pids.append(self.rng.choice(self.pos_ids))
                else:
                    batch_pids.append(self.neg_ids[neg_rand_idxs[neg_rand_idxs_ptr]])
                    neg_rand_idxs_ptr += 1
                if len(batch_pids) == self.batch_size:
                    yield self.prepare_batch(batch_pids)
                    batch_pids = []
            # yield the half filled batch
            if len(batch_pids) > 0:
                yield self.prepare_batch(batch_pids)

            if not self.infinite:
                break

    def prepare_batch(self, batch_pids):
        """Load scans and candidates for ``batch_pids``; return ``(x, y, pids)``.

        ``x`` has shape ``(len(batch_pids), n_candidates_per_patient, 1) +
        transform_params['patch_size']``; ``y`` holds ``id2label.get(pid)``.
        """
        x_batch = np.zeros((len(batch_pids), self.n_candidates_per_patient, 1)
                           + self.transform_params['patch_size'], dtype='float32')
        y_batch = np.zeros((len(batch_pids),), dtype='float32')
        for i, pid in enumerate(batch_pids):
            patient_path = self.data_path + '/' + str(pid)
            img, pixel_spacing = utils_lung.read_dicom_scan(patient_path)
            all_candidates = utils.load_pkl(self.id2candidates_path[pid])
            top_candidates = all_candidates[:self.n_candidates_per_patient]
            if self.shuffle_top_n:
                self.rng.shuffle(top_candidates)
            # Insert a singleton channel axis after the candidate axis.
            x_batch[i] = np.float32(self.data_prep_fun(data=img,
                                                       patch_centers=top_candidates,
                                                       pixel_spacing=pixel_spacing))[:, None, :, :, :]
            y_batch[i] = self.id2label.get(pid)
        return x_batch, y_batch, batch_pids
1926
1927
class DSBDataGenerator(object):
    """Iterate once over DSB scans, yielding ``(x, pid)``.

    Args:
        data_path: directory containing one scan folder per patient id.
        transform_params: stored on the instance; not read by this class.
        data_prep_fun: optional callable(data, pixel_spacing) -> (x, tf_matrix).
            When falsy, the raw scan is yielded; ``tf_matrix`` is discarded.
        patient_pids: pids to visit. When None, every patient returned by
            ``utils_lung.get_patient_data_paths(data_path)`` is used. Previously
            that listing was computed and always overwritten, and None crashed.
        **kwargs: accepted and ignored.
    """

    def __init__(self, data_path, transform_params=None, data_prep_fun=None, patient_pids=None, **kwargs):
        if patient_pids is not None:
            self.patient_paths = [data_path + '/' + p for p in patient_pids]
        else:
            self.patient_paths = utils_lung.get_patient_data_paths(data_path)

        self.nsamples = len(self.patient_paths)
        self.data_path = data_path
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params

    def generate(self):
        """Yield ``(float32 scan or prepared scan, pid)`` per patient (a single pass)."""
        for p in self.patient_paths:
            pid = utils_lung.extract_pid_dir(p)

            img, pixel_spacing = utils_lung.read_dicom_scan(p)

            if self.data_prep_fun:
                x, _ = self.data_prep_fun(data=img, pixel_spacing=pixel_spacing)
            else:
                x = img

            yield np.float32(x), pid
1952
1953
1954
1955
1956
class CandidatesPropertiesLunaDataGenerator(object):
    """Batch LUNA candidate patches with per-candidate property targets.

    Positive samples are every annotation from ``read_luna_properties`` for the
    given patients. Negative samples come from ``read_luna_negative_candidates``
    and are added until positives make up roughly ``positive_proportion`` of
    ``nsamples``. With ``random_negative_samples`` the negative slots are
    placeholders that are resampled on every visit; otherwise they are drawn
    once here with ``rng``.

    Targets are ``label_prep_fun(annotation, properties_included)``; their
    width is ``len(properties_included)`` when that is non-empty, otherwise
    ``nproperties``.

    Raises:
        ValueError: if ``transform_params['pixel_spacing']`` is not (1., 1., 1.).
    """

    def __init__(self, data_path, batch_size, transform_params, label_prep_fun,
                 nproperties, patient_ids, data_prep_fun, rng,
                 full_batch, random, infinite, positive_proportion, properties_included=None,
                 random_negative_samples=False, **kwargs):
        # Validate configuration up front (an assert would vanish under -O).
        if transform_params['pixel_spacing'] != (1., 1., 1.):
            raise ValueError('CandidatesPropertiesLunaDataGenerator expects pixel_spacing (1., 1., 1.), got %r'
                             % (transform_params['pixel_spacing'],))
        # A fresh list per instance: the list is stored and handed to label_prep_fun.
        if properties_included is None:
            properties_included = []

        id2positive_annotations = utils_lung.read_luna_properties(pathfinder.LUNA_PROPERTIES_PATH)
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)

        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
        self.id2positive_annotations = {}
        self.id2negative_annotations = {}
        self.pid2patient_path = {}
        n_positive = 0
        for pid in patient_ids:
            self.pid2patient_path[pid] = data_path + '/' + pid + self.file_extension
            if pid in id2positive_annotations:
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
                n_positive += len(id2positive_annotations[pid])
            if pid in id2negative_annotations:
                self.id2negative_annotations[pid] = id2negative_annotations[pid]

        # Total count such that positives make up positive_proportion of it.
        self.nsamples = int(n_positive + (1. - positive_proportion) / positive_proportion * n_positive)
        print('n samples %s' % (self.nsamples,))
        self.idx2pid_annotation = {}
        i = 0
        for pid, annotations in self.id2positive_annotations.items():
            for a in annotations:
                self.idx2pid_annotation[i] = (pid, a)
                i += 1
        print('n positive %s' % (len(self.idx2pid_annotation),))

        if random_negative_samples:
            # (None, None) placeholders are resampled in generate() on each visit.
            while i < self.nsamples:
                self.idx2pid_annotation[i] = (None, None)
                i += 1
        else:
            negative_pids = list(self.id2negative_annotations.keys())
            while i < self.nsamples:
                pid = rng.choice(negative_pids)
                patient_annotations = self.id2negative_annotations[pid]
                a = patient_annotations[rng.randint(len(patient_annotations))]
                self.idx2pid_annotation[i] = (pid, a)
                i += 1
        assert len(self.idx2pid_annotation) == self.nsamples

        self.data_path = data_path
        self.batch_size = batch_size
        self.rng = rng
        self.full_batch = full_batch
        self.random = random
        self.infinite = infinite
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.positive_proportion = positive_proportion
        self.label_prep_fun = label_prep_fun
        self.nlabels = nproperties

        if len(properties_included) > 0:
            self.nlabels = len(properties_included)
        self.properties_included = properties_included

    def generate(self):
        """Yield ``(x_batch, y_batch, patients_ids)``.

        With ``full_batch`` set, a trailing partial batch is dropped.
        """
        while True:
            rand_idxs = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(rand_idxs)
            for pos in range(0, len(rand_idxs), self.batch_size):
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
                nb = len(idxs_batch)
                # allocate batches
                x_batch = np.zeros((nb,) + self.transform_params['patch_size'], dtype='float32')
                y_batch = np.zeros((nb, self.nlabels), dtype='float32')
                patients_ids = []

                for i, idx in enumerate(idxs_batch):
                    pid, patch_annotation = self.idx2pid_annotation[idx]

                    if pid is None:
                        # Placeholder slot: draw a fresh random negative candidate.
                        pid = self.rng.choice(list(self.id2negative_annotations.keys()))
                        patient_annotations = self.id2negative_annotations[pid]
                        patch_annotation = patient_annotations[self.rng.randint(len(patient_annotations))]

                    patient_path = self.pid2patient_path[pid]
                    patients_ids.append(pid)

                    y_batch[i] = self.label_prep_fun(patch_annotation, self.properties_included)

                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)

                    # First four annotation entries; the name suggests
                    # (z, y, x, diameter) — assumption, confirm against readers.
                    patch_zyxd = patch_annotation[:4]
                    x_batch[i, :, :, :] = self.data_prep_fun(data=img, pid=pid,
                                                             patch_center=patch_zyxd,
                                                             pixel_spacing=pixel_spacing,
                                                             luna_origin=origin)
                if not self.full_batch or nb == self.batch_size:
                    yield x_batch, y_batch, patients_ids

            if not self.infinite:
                break
2063
2064
2065
class CandidatesPropertiesLunaDataGenerator2(object):
    """Batch generator mixing LUNA property-annotated nodules with negative candidates.

    Every positive annotation (read from ``pathfinder.LUNA_PROPERTIES_PATH``) of the
    given patients becomes one sample; negatives (from ``pathfinder.LUNA_CANDIDATES_PATH``)
    pad the sample list so that positives make up about ``positive_proportion`` of it.
    """

    def __init__(self, data_path, batch_size, transform_params, label_prep_fun,
                 nproperties, patient_ids, data_prep_fun, rng,
                 full_batch, random, infinite, positive_proportion, properties_included=None,
                 random_negative_samples=False, **kwargs):
        """Index the positive annotations and reserve (or pre-draw) the negative samples.

        :param data_path: directory with one ``<pid>.pkl`` / ``<pid>.mhd`` scan per patient
            (``.pkl`` is chosen when the substring ``'pkl'`` occurs anywhere in the path).
        :param batch_size: number of samples per yielded batch.
        :param transform_params: must contain ``'patch_size'`` and ``'pixel_spacing'``;
            the latter has to equal ``(1., 1., 1.)``.
        :param label_prep_fun: called as ``label_prep_fun(annotation, properties_included)``
            to build each sample's target.
        :param nproperties: target width used when ``properties_included`` is empty.
        :param patient_ids: patients to draw samples from.
        :param data_prep_fun: builds each input patch (see ``generate``).
        :param rng: random generator providing ``choice``, ``randint`` and ``shuffle``
            (presumably a ``numpy.random.RandomState`` -- confirm with callers).
        :param full_batch: if true, a trailing batch shorter than ``batch_size`` is skipped.
        :param random: shuffle the sample order at the start of every epoch.
        :param infinite: keep looping over epochs instead of stopping after one.
        :param positive_proportion: intended fraction of positives, in ``(0, 1]``.
        :param properties_included: optional list of properties; when non-empty its length
            overrides ``nproperties``. Defaults to an empty list.
        :param random_negative_samples: if true, negatives are re-drawn each time they are
            visited in ``generate``; otherwise they are drawn once here with ``rng``.
        :raises ValueError: if ``positive_proportion`` is outside ``(0, 1]``, or if a
            negative has to be drawn but none of ``patient_ids`` has negative candidates.
        """
        # Validate before any file I/O: 0 would divide by zero below and values > 1
        # would ask for a negative number of negatives.
        if not 0. < positive_proportion <= 1.:
            raise ValueError('positive_proportion must be in (0, 1], got %r' % (positive_proportion,))
        if properties_included is None:
            properties_included = []

        id2positive_annotations = utils_lung.read_luna_properties(pathfinder.LUNA_PROPERTIES_PATH)
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)

        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
        self.id2positive_annotations = {}
        self.id2negative_annotations = {}
        self.pid2patient_path = {}
        n_positive = 0
        for pid in patient_ids:
            self.pid2patient_path[pid] = data_path + '/' + pid + self.file_extension
            if pid in id2positive_annotations:
                self.id2positive_annotations[pid] = id2positive_annotations[pid]
                n_positive += len(id2positive_annotations[pid])
            if pid in id2negative_annotations:
                self.id2negative_annotations[pid] = id2negative_annotations[pid]

        # positives + (1 - p) / p * positives, truncated to an int.
        self.nsamples = int(n_positive + (1. - positive_proportion) / positive_proportion * n_positive)
        print('n samples %d' % self.nsamples)

        # Indices [0, n_positive) hold the positives; the remaining ones hold negatives.
        self.idx2pid_annotation = {}
        i = 0
        for pid, annotations in self.id2positive_annotations.items():
            for a in annotations:
                self.idx2pid_annotation[i] = (pid, a)
                i += 1
        print('n positive %d' % len(self.idx2pid_annotation))

        while i < self.nsamples:
            if random_negative_samples:
                # Placeholder: generate() draws a fresh negative on every visit.
                self.idx2pid_annotation[i] = (None, None)
            else:
                self.idx2pid_annotation[i] = self._sample_negative(rng)
            i += 1
        assert len(self.idx2pid_annotation) == self.nsamples

        self.data_path = data_path
        self.batch_size = batch_size
        self.rng = rng
        self.full_batch = full_batch
        self.random = random
        self.infinite = infinite
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.positive_proportion = positive_proportion
        self.label_prep_fun = label_prep_fun
        self.nlabels = len(properties_included) if properties_included else nproperties
        self.properties_included = properties_included

        assert self.transform_params['pixel_spacing'] == (1., 1., 1.)

    def _sample_negative(self, rng):
        """Return a random ``(pid, annotation)`` negative sample.

        A patient is picked with ``rng.choice`` among those having negative candidates,
        then one of its candidates with ``rng.randint``.

        :raises ValueError: if there are no negative candidates at all.
        """
        pids = list(self.id2negative_annotations.keys())
        if not pids:
            raise ValueError('no negative candidates available for the given patient_ids')
        pid = rng.choice(pids)
        patient_annotations = self.id2negative_annotations[pid]
        return pid, patient_annotations[rng.randint(len(patient_annotations))]

    def generate(self):
        """Yield ``(x_batch, y_batch, patients_ids)`` batches, epoch after epoch.

        ``x_batch`` has shape ``(nb,) + transform_params['patch_size']`` and holds
        ``data_prep_fun(data=img, pid=pid, patch_center=annotation[:4],
        pixel_spacing=..., luna_origin=...)`` per sample. ``y_batch`` has shape ``(nb,)``
        when ``nlabels == 1`` and ``(nb, nlabels)`` otherwise. Both are float32, and ``nb``
        equals ``batch_size`` except possibly for the last batch of an epoch. The scan
        is re-read from disk for every sample.
        """
        patch_size = self.transform_params['patch_size']
        label_shape = () if self.nlabels == 1 else (self.nlabels,)
        while True:
            rand_idxs = np.arange(self.nsamples)
            if self.random:
                self.rng.shuffle(rand_idxs)
            for pos in range(0, self.nsamples, self.batch_size):
                idxs_batch = rand_idxs[pos:pos + self.batch_size]
                nb = len(idxs_batch)
                x_batch = np.zeros((nb,) + patch_size, dtype='float32')
                y_batch = np.zeros((nb,) + label_shape, dtype='float32')
                patients_ids = []

                for i, idx in enumerate(idxs_batch):
                    pid, patch_annotation = self.idx2pid_annotation[idx]
                    if pid is None:
                        # random_negative_samples placeholder: draw a new negative now.
                        pid, patch_annotation = self._sample_negative(self.rng)
                    patients_ids.append(pid)
                    y_batch[i] = self.label_prep_fun(patch_annotation, self.properties_included)

                    patient_path = self.pid2patient_path[pid]
                    img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
                        if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
                    x_batch[i] = self.data_prep_fun(data=img, pid=pid,
                                                    patch_center=patch_annotation[:4],
                                                    pixel_spacing=pixel_spacing,
                                                    luna_origin=origin)

                # With full_batch, a trailing short batch is dropped.
                if not self.full_batch or nb == self.batch_size:
                    yield x_batch, y_batch, patients_ids

            if not self.infinite:
                break
2175
2176
2177
class CandidatesLunaValidDataGenerator2(object):
    """Fixed validation set built from LUNA annotations.

    The set holds every property annotation of the given patients plus, per patient,
    a seeded random subset of its negative candidates of (at most) the same size.
    """

    def __init__(self, data_path, transform_params, patient_ids, data_prep_fun, label_prep_fun=None,
                 properties_included=None, **kwargs):
        """Select the validation patches.

        :param data_path: directory with one ``<pid>.pkl`` / ``<pid>.mhd`` scan per patient
            (``.pkl`` is chosen when the substring ``'pkl'`` occurs anywhere in the path).
        :param transform_params: when ``label_prep_fun`` is given, its ``'pixel_spacing'``
            must equal ``(1., 1., 1.)``.
        :param patient_ids: candidate patients; only those with property annotations are kept.
        :param data_prep_fun: builds each input patch (see ``generate``).
        :param label_prep_fun: optional; if given, a positive's target is
            ``label_prep_fun(annotation, properties_included)`` instead of ``1.``.
        :param properties_included: forwarded to ``label_prep_fun``; defaults to an empty list.
        """
        if properties_included is None:
            properties_included = []
        # Fixed seed: the negative subset is reproducible for the same inputs.
        rng = np.random.RandomState(42)  # do not change this!!!

        id2positive_annotations = utils_lung.read_luna_properties(pathfinder.LUNA_PROPERTIES_PATH)
        id2negative_annotations = utils_lung.read_luna_negative_candidates(pathfinder.LUNA_CANDIDATES_PATH)

        self.file_extension = '.pkl' if 'pkl' in data_path else '.mhd'
        self.id2positive_annotations = {}
        self.id2negative_annotations = {}
        self.id2patient_path = {}
        n_positive, n_negative = 0, 0
        for pid in patient_ids:
            if pid not in id2positive_annotations:
                continue
            positive_annotations = id2positive_annotations[pid]
            negative_annotations = id2negative_annotations.get(pid, [])
            n_pos = len(positive_annotations)
            n_neg = len(negative_annotations)
            if n_neg > 0:
                # Same rng call as before whenever n_neg >= n_pos, so the seeded split is
                # unchanged; with fewer negatives than positives take them all instead of
                # letting replace=False raise.
                neg_idxs = rng.choice(n_neg, size=min(n_pos, n_neg), replace=False)
            else:
                neg_idxs = []
            self.id2positive_annotations[pid] = positive_annotations
            self.id2negative_annotations[pid] = [negative_annotations[i] for i in neg_idxs]
            self.id2patient_path[pid] = data_path + '/' + pid + self.file_extension
            n_positive += n_pos
            n_negative += len(self.id2negative_annotations[pid])

        print('n positive %d' % n_positive)
        print('n negative %d' % n_negative)

        # NOTE(review): this counts patients, not the n_positive + n_negative patches that
        # generate() yields -- confirm callers expect a patient count.
        self.nsamples = len(self.id2patient_path)
        self.data_path = data_path
        self.rng = rng
        self.data_prep_fun = data_prep_fun
        self.transform_params = transform_params
        self.label_prep_fun = label_prep_fun
        if label_prep_fun is not None:
            assert self.transform_params['pixel_spacing'] == (1., 1., 1.)

        self.properties_included = properties_included

    def _load_patch(self, pid, patient_path, annotation):
        """Read the patient's scan and return its prepared patch with a leading axis of size 1.

        The patch is ``data_prep_fun(data=img, pid=pid, patch_center=annotation[:4], ...)``
        cast to float32.
        """
        img, origin, pixel_spacing = utils_lung.read_pkl(patient_path) \
            if self.file_extension == '.pkl' else utils_lung.read_mhd(patient_path)
        return np.float32(self.data_prep_fun(data=img, pid=pid,
                                             patch_center=annotation[:4],
                                             pixel_spacing=pixel_spacing,
                                             luna_origin=origin))[None, ...]

    def generate(self):
        """Yield ``(x_batch, y_batch, [pid])`` one sample at a time, in a single pass.

        For each patient, its positives come first with ``y_batch`` equal to ``[1.]`` (or
        ``[label_prep_fun(annotation, properties_included)]``). Its selected negatives
        follow with ``y_batch`` equal to ``[0.]``. The scan is re-read for every sample.
        NOTE(review): if ``label_prep_fun`` returns a vector, positive and negative
        targets have different shapes -- confirm this is what consumers expect.
        """
        for pid in self.id2positive_annotations:
            patient_path = self.id2patient_path[pid]
            for annotation in self.id2positive_annotations[pid]:
                if self.label_prep_fun is None:
                    y_batch = np.array([1.], dtype='float32')
                else:
                    y_batch = np.array([self.label_prep_fun(annotation, self.properties_included)],
                                       dtype='float32')
                yield self._load_patch(pid, patient_path, annotation), y_batch, [pid]

            for annotation in self.id2negative_annotations[pid]:
                y_batch = np.array([0.], dtype='float32')
                yield self._load_patch(pid, patient_path, annotation), y_batch, [pid]