[5d12a0]: / ants / core / ants_metric.py

Download this file

140 lines (105 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
ANTs ImageToImageMetric class
"""
import ants
class ANTsImageToImageMetric(object):
    """
    Python-side wrapper around a low-level image-to-image metric object.

    The wrapped object is stored as ``self._metric``; every operation is
    forwarded to it. This class adds input validation, remembers which
    images/masks were assigned, and tracks whether the wrapped metric has
    been initialized since its configuration last changed.
    """

    def __init__(self, metric):
        """
        Arguments
        ---------
        metric : object
            Low-level metric object. It must expose the attributes and
            methods used by this class (``precision``, ``dimension``,
            ``metrictype``, ``isVector``, ``pointer``, ``setFixedImage``,
            ``setMovingImage``, ``setSampling``, ``initialize``,
            ``getValue``).
        """
        self._metric = metric
        self._is_initialized = False
        self.fixed_image = None
        self.fixed_mask = None
        self.moving_image = None
        self.moving_mask = None

    # ------------------------------------------
    # PROPERTIES

    @property
    def precision(self):
        """Precision reported by the wrapped metric."""
        return self._metric.precision

    @property
    def dimension(self):
        """Dimension reported by the wrapped metric."""
        return self._metric.dimension

    @property
    def metrictype(self):
        """Metric type name with the 'ImageToImageMetricv4' part removed."""
        return self._metric.metrictype.replace('ImageToImageMetricv4', '')

    @property
    def is_vector(self):
        """True when the wrapped metric's ``isVector`` attribute equals 1."""
        return self._metric.isVector == 1

    @property
    def pointer(self):
        """Pointer attribute of the wrapped metric."""
        return self._metric.pointer

    # ------------------------------------------
    # METHODS

    def _check_image(self, image):
        """
        Validate that `image` is an ANTsImage whose dimension matches this
        metric.

        Raises
        ------
        ValueError
            If `image` is not an ANTsImage, or its dimension differs from
            the metric's dimension.
        """
        if not ants.is_image(image):
            raise ValueError('image must be ANTsImage type')
        if image.dimension != self.dimension:
            raise ValueError('image dim (%i) does not match metric dim (%i)' % (image.dimension, self.dimension))

    def set_fixed_image(self, image):
        """
        Set Fixed ANTsImage for metric

        Raises ValueError if `image` is not an ANTsImage or has the wrong
        dimension.
        """
        self._check_image(image)
        self._metric.setFixedImage(image.pointer, False)
        self.fixed_image = image
        # The configuration changed, so a previous initialize() is stale;
        # get_value() must initialize again before evaluating.
        self._is_initialized = False

    def set_fixed_mask(self, image):
        """
        Set Fixed ANTsImage Mask for metric

        Raises ValueError if `image` is not an ANTsImage or has the wrong
        dimension.
        """
        self._check_image(image)
        # Masks go through the same low-level setter as images, with the
        # second argument set to True instead of False.
        self._metric.setFixedImage(image.pointer, True)
        self.fixed_mask = image
        self._is_initialized = False

    def set_moving_image(self, image):
        """
        Set Moving ANTsImage for metric

        Raises ValueError if `image` is not an ANTsImage or has the wrong
        dimension.
        """
        self._check_image(image)
        self._metric.setMovingImage(image.pointer, False)
        self.moving_image = image
        self._is_initialized = False

    def set_moving_mask(self, image):
        """
        Set Moving ANTsImage Mask for metric

        Raises ValueError if `image` is not an ANTsImage or has the wrong
        dimension.
        """
        self._check_image(image)
        # Masks go through the same low-level setter as images, with the
        # second argument set to True instead of False.
        self._metric.setMovingImage(image.pointer, True)
        self.moving_mask = image
        self._is_initialized = False

    def set_sampling(self, strategy='regular', percentage=1.):
        """
        Forward the sampling strategy and percentage to the wrapped metric.

        Raises
        ------
        ValueError
            If the fixed or moving image has not been set yet.
        """
        if (self.fixed_image is None) or (self.moving_image is None):
            raise ValueError('must set fixed_image and moving_image before setting sampling')
        self._metric.setSampling(strategy, percentage)
        # Treat a sampling change like any other configuration change.
        self._is_initialized = False

    def initialize(self):
        """
        Initialize the wrapped metric and mark this wrapper as initialized.

        Raises
        ------
        ValueError
            If the fixed or moving image has not been set yet.
        """
        if (self.fixed_image is None) or (self.moving_image is None):
            raise ValueError('must set fixed_image and moving_image before initializing')
        self._metric.initialize()
        self._is_initialized = True

    def get_value(self):
        """
        Return the wrapped metric's value, initializing first when the
        metric has not been initialized since its last configuration
        change.
        """
        if not self._is_initialized:
            self.initialize()
        return self._metric.getValue()

    def __call__(self, fixed, moving, fixed_mask=None, moving_mask=None, sampling_strategy='regular', sampling_percentage=1.):
        """
        Configure the metric with the given images (and optional masks and
        sampling settings), initialize it, and return its value.

        Raises ValueError if any supplied image or mask fails validation.
        """
        self.set_fixed_image(fixed)
        self.set_moving_image(moving)

        if fixed_mask is not None:
            self.set_fixed_mask(fixed_mask)
        if moving_mask is not None:
            self.set_moving_mask(moving_mask)

        # NOTE(review): if only one of the two sampling arguments is None,
        # that None is forwarded to the wrapped metric unchanged; confirm
        # whether the low-level setter accepts it.
        if (sampling_strategy is not None) or (sampling_percentage is not None):
            self.set_sampling(sampling_strategy, sampling_percentage)

        self.initialize()
        return self.get_value()

    def __repr__(self):
        """Multi-line summary: dimension, precision, metric type, vector flag."""
        s = "ANTsImageToImageMetric\n" +\
            '\t {:<10} : {}\n'.format('Dimension', self.dimension)+\
            '\t {:<10} : {}\n'.format('Precision', self.precision)+\
            '\t {:<10} : {}\n'.format('MetricType', self.metrictype)+\
            '\t {:<10} : {}\n'.format('IsVector', self.is_vector)
        return s