|
a |
|
b/src/labelOverlapMeasures.cxx |
|
|
1 |
|
|
|
2 |
#include <nanobind/nanobind.h> |
|
|
3 |
#include <nanobind/stl/vector.h> |
|
|
4 |
#include <nanobind/stl/string.h> |
|
|
5 |
#include <nanobind/stl/tuple.h> |
|
|
6 |
#include <nanobind/stl/list.h> |
|
|
7 |
#include <nanobind/ndarray.h> |
|
|
8 |
#include <nanobind/stl/shared_ptr.h> |
|
|
9 |
|
|
|
10 |
#include <exception> |
|
|
11 |
#include <vector> |
|
|
12 |
#include <string> |
|
|
13 |
|
|
|
14 |
#include "itkImage.h" |
|
|
15 |
#include "itkLabelOverlapMeasuresImageFilter.h" |
|
|
16 |
|
|
|
17 |
#include "antsImage.h" |
|
|
18 |
|
|
|
19 |
namespace nb = nanobind; |
|
|
20 |
using namespace nb::literals; |
|
|
21 |
|
|
|
22 |
template<class PrecisionType, unsigned int ImageDimension> |
|
|
23 |
nb::dict labelOverlapMeasures( AntsImage<itk::Image<PrecisionType, ImageDimension>> & antsSourceImage, |
|
|
24 |
AntsImage<itk::Image<PrecisionType, ImageDimension>> & antsTargetImage ) |
|
|
25 |
{ |
|
|
26 |
using ImageType = itk::Image<PrecisionType, ImageDimension>; |
|
|
27 |
using ImagePointerType = typename ImageType::Pointer; |
|
|
28 |
|
|
|
29 |
typename ImageType::Pointer itkSourceImage = antsSourceImage.ptr; |
|
|
30 |
typename ImageType::Pointer itkTargetImage = antsTargetImage.ptr; |
|
|
31 |
|
|
|
32 |
using FilterType = itk::LabelOverlapMeasuresImageFilter<ImageType>; |
|
|
33 |
typename FilterType::Pointer filter = FilterType::New(); |
|
|
34 |
filter->SetSourceImage( itkSourceImage ); |
|
|
35 |
filter->SetTargetImage( itkTargetImage ); |
|
|
36 |
filter->Update(); |
|
|
37 |
|
|
|
38 |
typename FilterType::MapType labelMap = filter->GetLabelSetMeasures(); |
|
|
39 |
|
|
|
40 |
// Sort the labels |
|
|
41 |
|
|
|
42 |
std::vector<PrecisionType> allLabels; |
|
|
43 |
allLabels.clear(); |
|
|
44 |
for( typename FilterType::MapType::const_iterator it = labelMap.begin(); |
|
|
45 |
it != labelMap.end(); ++it ) |
|
|
46 |
{ |
|
|
47 |
if( (*it).first == 0 ) |
|
|
48 |
{ |
|
|
49 |
continue; |
|
|
50 |
} |
|
|
51 |
|
|
|
52 |
const int label = (*it).first; |
|
|
53 |
allLabels.push_back( label ); |
|
|
54 |
} |
|
|
55 |
std::sort( allLabels.begin(), allLabels.end() ); |
|
|
56 |
|
|
|
57 |
|
|
|
58 |
// Now put the results in an Rcpp data frame |
|
|
59 |
|
|
|
60 |
unsigned int vectorLength = 1 + allLabels.size(); |
|
|
61 |
|
|
|
62 |
std::vector<PrecisionType> labels( vectorLength ); |
|
|
63 |
std::vector<double> totalOrTargetOverlap( vectorLength ); |
|
|
64 |
std::vector<double> unionOverlap( vectorLength ); |
|
|
65 |
std::vector<double> meanOverlap( vectorLength ); |
|
|
66 |
std::vector<double> volumeSimilarity( vectorLength ); |
|
|
67 |
std::vector<double> falseNegativeError( vectorLength ); |
|
|
68 |
std::vector<double> falsePositiveError( vectorLength ); |
|
|
69 |
|
|
|
70 |
// We'll replace label '0' with "All" in the R wrapper. |
|
|
71 |
labels[0] = itk::NumericTraits<PrecisionType>::Zero; |
|
|
72 |
totalOrTargetOverlap[0] = filter->GetTotalOverlap(); |
|
|
73 |
unionOverlap[0] = filter->GetUnionOverlap(); |
|
|
74 |
meanOverlap[0] = filter->GetMeanOverlap(); |
|
|
75 |
volumeSimilarity[0] = filter->GetVolumeSimilarity(); |
|
|
76 |
falseNegativeError[0] = filter->GetFalseNegativeError(); |
|
|
77 |
falsePositiveError[0] = filter->GetFalsePositiveError(); |
|
|
78 |
|
|
|
79 |
unsigned int i = 1; |
|
|
80 |
typename std::vector<PrecisionType>::const_iterator itL = allLabels.begin(); |
|
|
81 |
for( itL = allLabels.begin(); itL != allLabels.end(); ++itL ) |
|
|
82 |
{ |
|
|
83 |
labels[i] = *itL; |
|
|
84 |
totalOrTargetOverlap[i] = filter->GetTargetOverlap( *itL ); |
|
|
85 |
unionOverlap[i] = filter->GetUnionOverlap( *itL ); |
|
|
86 |
meanOverlap[i] = filter->GetMeanOverlap( *itL ); |
|
|
87 |
volumeSimilarity[i] = filter->GetVolumeSimilarity( *itL ); |
|
|
88 |
falseNegativeError[i] = filter->GetFalseNegativeError( *itL ); |
|
|
89 |
falsePositiveError[i] = filter->GetFalsePositiveError( *itL ); |
|
|
90 |
i++; |
|
|
91 |
} |
|
|
92 |
|
|
|
93 |
nb::dict labelOverlapMeasures; |
|
|
94 |
labelOverlapMeasures["Label"] = labels; |
|
|
95 |
labelOverlapMeasures["TotalOrTargetOverlap"] = totalOrTargetOverlap; |
|
|
96 |
labelOverlapMeasures["UnionOverlap"] = unionOverlap; |
|
|
97 |
labelOverlapMeasures["MeanOverlap"] = meanOverlap; |
|
|
98 |
labelOverlapMeasures["VolumeSimilarity"] = volumeSimilarity; |
|
|
99 |
labelOverlapMeasures["FalseNegativeError"] = falseNegativeError; |
|
|
100 |
labelOverlapMeasures["FalsePositiveError"] = falsePositiveError; |
|
|
101 |
|
|
|
102 |
return labelOverlapMeasures; |
|
|
103 |
} |
|
|
104 |
|
|
|
105 |
void local_labelOverlapMeasures(nb::module_ &m) |
|
|
106 |
{ |
|
|
107 |
m.def("labelOverlapMeasures2D", &labelOverlapMeasures<unsigned int, 2>); |
|
|
108 |
m.def("labelOverlapMeasures3D", &labelOverlapMeasures<unsigned int, 3>); |
|
|
109 |
m.def("labelOverlapMeasures4D", &labelOverlapMeasures<unsigned int, 4>); |
|
|
110 |
} |