[797475]: / dicom_and_image_tools / segment / sitkstrats.py

Download this file

483 lines (358 with data), 16.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
'''A collection of segmentation strategies from sitk instead of itk.'''
import SimpleITK as sitk # pylint: disable=F0401
import numpy as np
import datetime
import os
from functools import wraps
import logging
import lungseg
def write(img, fname, compression=True):
'''
A wrapper for sitk WriteImage that defaults to using compression and
creates directories where required.
'''
# ImageFileWriter fails if the directory doesn't exist. Create if req'd
try:
os.makedirs(os.path.dirname(fname))
except OSError:
pass
sitk.WriteImage(img, fname, compression)
def read(fname):
img = sitk.ReadImage(fname)
return img
def hash_img(img, provenance=""):
'''
Calculate the hash of an image and the options that would be used to
process it using sha512. This is used most frequently to cache images for
later reuse.
'''
import hashlib
sha = hashlib.sha512()
sha.update(sitk.GetArrayFromImage(img))
sha.update(provenance)
return sha.hexdigest()
def log_size(func):
'''A decorator that calculates the size of a segmentation.'''
@wraps(func)
def exec_func(img, opts=None):
'''Execute func from outer context and compute the size of the image
func produces.'''
if opts is None:
opts = {}
(img_out, opts_out) = func(img, opts)
opts['size'] = np.count_nonzero( # pylint: disable=E1101
sitk.GetArrayFromImage(img_out))
return (img_out, opts_out)
return exec_func
def options_log(func):
'''A decorator that will modify the incoming options object to also include
information about runtime and algorithm choice.'''
@wraps(func)
def exec_func(*args, **kwargs):
'''The inner function for options_log'''
start = datetime.datetime.now()
(img, out_opts) = func(*args, **kwargs)
out_opts['algorithm'] = func.__name__
out_opts['time'] = datetime.datetime.now() - start
return (img, out_opts)
return exec_func
def cached(relevant_opts, max_cache_size=1):
'''A decorator that uses options and input image to cache an image for
possible later reuse.'''
def cached_decorator(func):
'''
Produce a decorator configured with the options defined in 'cached'
above.
'''
from collections import OrderedDict
# build a cache once for each method?
cache = OrderedDict()
@wraps(func)
def exec_func(img_in, opts): # pylint: disable=C0111
# SHA should be built based only on the options defined in
# relevant opts. This set is not passed to the internal function,
# however, since the output is decorated by other functions to
# report on outcomes and whatnot
limited_opts = dict([(i, opts[i]) for i in opts
if i in relevant_opts])
limited_opts['func'] = func.__name__
sha = hash_img(img_in,
provenance=str(limited_opts))
if sha in cache:
logging.info("Loading '" + sha + "' from " + func.__name__ +
" cache (size=" + str(len(cache)) + ")")
img = cache[sha]
# refresh this image in the order, so we re-add it to the cache
del cache[sha]
else:
logging.info("Cache miss for '" + sha + "' from "
+ func.__name__ + "(size="+str(len(cache)) + ")")
(img, opts) = func(img_in, opts)
cache[sha] = img
# if the cache size exceeds the max, ditch the oldest entry
# this shouldn't ever actually get executed more than once, but
# you never know--recursion, for example, could cause many pushes
# without appropriate pops.
while len(cache) > max_cache_size:
# the last item in the list (FIFO) popped only when last=True
cache.popitem(last=True)
return (img, opts)
return exec_func
return cached_decorator
def com_calc(img, max_size, min_size, lung_img):
'''
Calculate the center of mass of each of the labeled regions in img,
excluding regions that are outside the lung given in lung_img, or the
range size [min_size, max_size], which are reported treated as fractions of
the input lung.
'''
from scipy.ndimage.measurements import center_of_mass as com
# pylint: disable=E1101
arr = sitk.GetArrayFromImage(img)
lung_arr = sitk.GetArrayFromImage(lung_img)
# Take elements from arr only when lung_arr is not zero, i.e. take only
# regions in the lung.
arr = np.where(lung_arr != 0, arr,
np.zeros(arr.shape,
dtype=arr.dtype))
counts = np.bincount(np.ravel(arr))
# volume per voxel is encoded in img spacing, with units mm^3
vox_vol = reduce(lambda x, y: x * y, img.GetSpacing())
# the size of the lung is the size of a voxel times the number of voxels
lung_size = np.count_nonzero(lung_arr)*vox_vol
# print sorted(counts)[-10:], sorted([c for c in counts if c != 0])[0:10]
# print np.count_nonzero(lung_arr)*vox_vol
logging.debug("Availiable labels and sizes: %s",
[p for p in enumerate(counts) if p[1] > 0])
# We gate the deterministic seeds for their regions being of a reasonable
# size.
labels = [label for (label, n_vox) in enumerate(counts)
if min_size*lung_size < n_vox*vox_vol < max_size*lung_size]
# we don't want to calculate the COM of the background, no matter what
# our min/max sizes are
try:
labels.remove(0)
except ValueError:
#Zero wasn't in the list
pass
# compute the center of mass for each of the elements up to 100 elements
com_list = com(np.where(arr != 0, np.ones(arr.shape), np.zeros(arr.shape)),
labels=arr, index=labels)
logging.debug("Label-COM correspondence: %s", dict(zip(labels, com_list)))
# these are array-indexed and we take our seeds to be image-indexed
# plus, they're floats and need to be cast back to integers
seeds = [[int(k) for k in reversed(s)] for s in com_list
if lung_arr[s] == 1]
info = {'nseeds': len(seeds),
'max_size': max_size,
'min_size': min_size,
'seeds': [s for s in seeds]} # deep (enough) copy
logging.info("%s in-lung of %s seeds from %s watershed labels",
len(seeds), len(labels), len(counts))
return (seeds, info)
@options_log
def crop_to_segmentation(img, seg_img, padding_ratio=None, padding_px=None):
'''
Given an image and a segmentation of that image, crop the image to include
only the portions of the image present in that segmentation. Only one of
padding_px and padding_ratio can be specified.
'''
from bounding import bounding_cube
# it's not allowed to specify both padding_ratio and padding_px
assert not ((padding_px is not None) and (padding_ratio is not None))
assert img.GetSize() == seg_img.GetSize()
assert img.GetSpacing() == seg_img.GetSpacing()
# calculate the bounding cube of the lung segmentation in numpy coordinates
lims = bounding_cube(sitk.GetArrayFromImage(seg_img))
# calculate the amount of each image to be removed (in itk indexing)
lower_remove = [l[0] for l in reversed(lims)]
upper_remove = [img.GetSize()[i] - l[1] for (i, l)
in enumerate(reversed(lims))]
if padding_px is None:
padding_ratio = 0.0
padding = [int(padding_ratio*img.GetSize()[i])
for i in range(len(img.GetSize()))]
elif padding_ratio is None:
assert padding_px is not None
padding = [padding_px]*len(img.GetSize())
padded_removes = [[r[i] - padding[i] for i in range(len(padding))]
for r in [lower_remove, upper_remove]]
logging.info("Cropped img to %s", lims)
if np.any(np.array(padded_removes) < 0):
raise ValueError("Padded removes " + str(padded_removes) +
" are invalid.")
return (sitk.Crop(img, *padded_removes),
{"origin": (lims[0][0], lims[1][0], lims[2][0]),
"padding": padding})
def distribute_seeds(img, n_pts=100):
'''Randomly distribute n seeds amongst all points where img != 0'''
import random
array = sitk.GetArrayFromImage(img)
seeds = list()
while len(seeds) < n_pts:
(z, y, x) = [random.randrange(0, i) for i in array.shape]
if array[z, y, x] != 0 and (z, y, x) not in seeds:
seeds.append((x, y, z))
return seeds
@cached(relevant_opts=["anisodiff", "gauss"])
def aniso_gauss(img_in, options):
'''CurvatureAnisotropicDiffusion + GradientMagnitudeRecursiveGaussian is a
a common featurization strategy. Compute these for consumption by other
sitkstrat functions.'''
img = sitk.Cast(img_in, sitk.sitkFloat32)
img = sitk.CurvatureAnisotropicDiffusion(
img,
timeStep=options['anisodiff']['timestep'],
conductanceParameter=options['anisodiff']['conductance'],
# options['anisodiff'].setdefault('scaling_interval', 1),
numberOfIterations=options['anisodiff']['iterations'])
img = sitk.GradientMagnitudeRecursiveGaussian(
img,
options['gauss']['sigma'])
return (img, options)
@log_size
@options_log
def segmentation_union(imgs, options):
'''Compute a consensus segmentation amongst a small set of segmentations'''
# pylint: disable=E1101
# Sadly, images and arrays have a different coordinate system (z, y, x) vs
# (x, y, z) so it's safest just to convert here. Don't worry, it's fast.
incl_count = np.zeros(sitk.GetArrayFromImage(imgs[0]).shape)
n_img = 0
# This would be memory-expensive for large numbers of images, but it's only
# a couple so ith's hopefully ok.
for img in [sitk.GetArrayFromImage(i) for i in imgs]:
assert img.shape == incl_count.shape
img_size = np.count_nonzero(img)
if img_size < options['max_size'] and \
img_size > options['min_size']:
n_img += 1
incl_count += (img != 0)
# store the number of images that passed QC
options['n_imgs'] = n_img
if n_img == 0:
raise RuntimeWarning("No images satisifed image size thresholds" +
str((options['min_size'], options['max_size'])))
# sitk is pretty particular about the datatypes that come in to
# GetImageFromArray, and the default output from the following (bool?)
# isn't acceptable apparently
consensus = np.array(incl_count >= options['threshold'] * n_img,
dtype='uint8')
consensus_size = np.count_nonzero(consensus)
if consensus_size < options['min_size']:
raise RuntimeWarning("Consensus image failed size threshold. " +
"Image too small at " + str(consensus_size))
consensus = sitk.GetImageFromArray(consensus)
consensus.CopyInformation(imgs[0])
consensus = sitk.Cast(consensus, sitk.sitkUInt8)
return (consensus, options)
@log_size
@options_log
def segment_lung(img, options):
'''Produce a lung segmentation from an input image.'''
orig = img
img = lungseg.otsu(img)
img = lungseg.ensure_img_border(img)
img = lungseg.find_components(img)
img = lungseg.isolate_lung_field(img)
img = lungseg.dialate(img, options['probe_size'])
img = lungseg.find_components(img)
img = lungseg.isolate_not_biggest(img)
img = sitk.BinaryErode(img, options['probe_size']-2,
sitk.BinaryErodeImageFilter.Ball)
return (img, options)
@options_log
def curvature_flow(img_in, options):
img = sitk.Cast(img_in, sitk.sitkFloat32)
img = sitk.CurvatureFlow(
img,
options['curvature_flow']['timestep'],
options['curvature_flow']['iterations'])
return (img, options)
@log_size
@options_log
def confidence_connected(img_in, options):
img = sitk.ConfidenceConnected(
img_in,
[options['seed']],
options['conf_connect']['iterations'],
options['conf_connect']['multiplier'],
options['conf_connect']['neighborhood'])
img = sitk.BinaryDilate(
img,
options['dialate']['radius'],
sitk.BinaryDilateImageFilter.Ball)
return (img, options)
@options_log
def aniso_gauss_watershed(img_in, options_in):
'''Compute CurvatureAnisotropicDiffusion +
GradientMagnitudeRecursiveGaussian + Sigmoid featurization of the image.'''
(img, options) = aniso_gauss(img_in, options_in)
img = sitk.MorphologicalWatershed(
img,
level=options['watershed']['level'],
markWatershedLine=True,
fullyConnected=False)
return (img, options)
@log_size
@options_log
def isolate_watershed(img_in, options):
'''Isolate a particular one of the watershed segmentations.'''
seed = options['seed']
arr = sitk.GetArrayFromImage(img_in)
label = arr[seed[2], seed[1], seed[0]]
options['label'] = int(label)
lab_arr = np.array(arr == label, dtype='float32') # pylint: disable=E1101
out_img = sitk.GetImageFromArray(lab_arr)
out_img.CopyInformation(img_in)
return (out_img, options)
@options_log
def aniso_gauss_sigmo(img_in, opts_in):
'''Compute CurvatureAnisotropicDiffusion +
GradientMagnitudeRecursiveGaussian + Sigmoid featurization of the image.'''
(img, options) = aniso_gauss(img_in, opts_in)
img = sitk.Sigmoid(
img,
options['sigmoid']['alpha'],
options['sigmoid']['beta'])
return (img, options)
@log_size
@options_log
def fastmarch_seeded_geocontour(img_in, options):
'''Segment img_in using a GeodesicActiveContourLevelSetImageFilter with an
inital level set built using FastMarchingImageFilter at options['seed']'''
# The speed of wave propagation should be one everywhere, so we produce
# an appropriately sized np array of all ones and convert it into an img
ones_img = sitk.GetImageFromArray(np.ones( # pylint: disable=E1101
sitk.GetArrayFromImage(img_in).shape))
ones_img.CopyInformation(img_in)
fastmarch = sitk.FastMarchingImageFilter()
# to save time, we limit the distances we calculate to a quarter of the
# image size away (i.e. a region no more than half the image in diameter).
fastmarch.SetStoppingValue(max(img_in.GetSize())*0.25)
seeds = sitk.VectorUIntList()
seeds.append(options['seed'])
fastmarch.SetTrialPoints(seeds)
seed_img = fastmarch.Execute(ones_img)
# FastMarch won't output the right PixelType, so we have to cast.
seed_img = sitk.Cast(seed_img, img_in.GetPixelID())
# Generally speaking, you're supposed to subtract an amount from the
# input level set, so that growing algorithm doesn't need to go as far
img_shifted = sitk.GetImageFromArray(
sitk.GetArrayFromImage(seed_img) - options['seed_shift'])
img_shifted.CopyInformation(seed_img)
seed_img = img_shifted
geodesic = sitk.GeodesicActiveContourLevelSetImageFilter()
geodesic.SetPropagationScaling(
options['geodesic']['propagation_scaling'])
geodesic.SetNumberOfIterations(
options['geodesic']['iterations'])
geodesic.SetCurvatureScaling(
options['geodesic']['curvature_scaling'])
geodesic.SetMaximumRMSError(
options['geodesic']['max_rms_change'])
out = geodesic.Execute(seed_img, img_in)
options['geodesic']['elapsed_iterations'] = geodesic.GetElapsedIterations()
options['geodesic']['rms_change'] = geodesic.GetRMSChange()
out = sitk.BinaryThreshold(out, insideValue=0, outsideValue=1)
return (out, options)