Switch to unified view

a b/src/labelOverlapMeasures.cxx
1
2
#include <nanobind/nanobind.h>
3
#include <nanobind/stl/vector.h>
4
#include <nanobind/stl/string.h>
5
#include <nanobind/stl/tuple.h>
6
#include <nanobind/stl/list.h>
7
#include <nanobind/ndarray.h>
8
#include <nanobind/stl/shared_ptr.h>
9
10
#include <exception>
11
#include <vector>
12
#include <string>
13
14
#include "itkImage.h"
15
#include "itkLabelOverlapMeasuresImageFilter.h"
16
17
#include "antsImage.h"
18
19
namespace nb = nanobind;
20
using namespace nb::literals;
21
22
template<class PrecisionType, unsigned int ImageDimension>
23
nb::dict labelOverlapMeasures( AntsImage<itk::Image<PrecisionType, ImageDimension>> &  antsSourceImage,
24
                                AntsImage<itk::Image<PrecisionType, ImageDimension>> &  antsTargetImage )
25
{
26
  using ImageType = itk::Image<PrecisionType, ImageDimension>;
27
  using ImagePointerType = typename ImageType::Pointer;
28
29
  typename ImageType::Pointer itkSourceImage = antsSourceImage.ptr;
30
  typename ImageType::Pointer itkTargetImage = antsTargetImage.ptr;
31
32
  using FilterType = itk::LabelOverlapMeasuresImageFilter<ImageType>;
33
  typename FilterType::Pointer filter = FilterType::New();
34
  filter->SetSourceImage( itkSourceImage );
35
  filter->SetTargetImage( itkTargetImage );
36
  filter->Update();
37
38
  typename FilterType::MapType labelMap = filter->GetLabelSetMeasures();
39
40
  // Sort the labels
41
42
  std::vector<PrecisionType> allLabels;
43
  allLabels.clear();
44
  for( typename FilterType::MapType::const_iterator it = labelMap.begin();
45
       it != labelMap.end(); ++it )
46
    {
47
    if( (*it).first == 0 )
48
      {
49
      continue;
50
      }
51
52
    const int label = (*it).first;
53
    allLabels.push_back( label );
54
    }
55
  std::sort( allLabels.begin(), allLabels.end() );
56
57
58
  // Now put the results in an Rcpp data frame
59
60
  unsigned int vectorLength = 1 + allLabels.size();
61
62
  std::vector<PrecisionType> labels( vectorLength );
63
  std::vector<double> totalOrTargetOverlap( vectorLength );
64
  std::vector<double> unionOverlap( vectorLength );
65
  std::vector<double> meanOverlap( vectorLength );
66
  std::vector<double> volumeSimilarity( vectorLength );
67
  std::vector<double> falseNegativeError( vectorLength );
68
  std::vector<double> falsePositiveError( vectorLength );
69
70
  // We'll replace label '0' with "All" in the R wrapper.
71
  labels[0] = itk::NumericTraits<PrecisionType>::Zero;
72
  totalOrTargetOverlap[0] = filter->GetTotalOverlap();
73
  unionOverlap[0] = filter->GetUnionOverlap();
74
  meanOverlap[0] = filter->GetMeanOverlap();
75
  volumeSimilarity[0] = filter->GetVolumeSimilarity();
76
  falseNegativeError[0] = filter->GetFalseNegativeError();
77
  falsePositiveError[0] = filter->GetFalsePositiveError();
78
79
  unsigned int i = 1;
80
  typename std::vector<PrecisionType>::const_iterator itL = allLabels.begin();
81
  for( itL = allLabels.begin(); itL != allLabels.end(); ++itL )
82
    {
83
    labels[i] = *itL;
84
    totalOrTargetOverlap[i] = filter->GetTargetOverlap( *itL );
85
    unionOverlap[i] = filter->GetUnionOverlap( *itL );
86
    meanOverlap[i] = filter->GetMeanOverlap( *itL );
87
    volumeSimilarity[i] = filter->GetVolumeSimilarity( *itL );
88
    falseNegativeError[i] = filter->GetFalseNegativeError( *itL );
89
    falsePositiveError[i] = filter->GetFalsePositiveError( *itL );
90
    i++;
91
    }
92
93
  nb::dict labelOverlapMeasures;
94
  labelOverlapMeasures["Label"] = labels;
95
  labelOverlapMeasures["TotalOrTargetOverlap"] = totalOrTargetOverlap;
96
  labelOverlapMeasures["UnionOverlap"] = unionOverlap;
97
  labelOverlapMeasures["MeanOverlap"] = meanOverlap;
98
  labelOverlapMeasures["VolumeSimilarity"] = volumeSimilarity;
99
  labelOverlapMeasures["FalseNegativeError"] = falseNegativeError;
100
  labelOverlapMeasures["FalsePositiveError"] = falsePositiveError;
101
102
  return labelOverlapMeasures;
103
}
104
105
void local_labelOverlapMeasures(nb::module_ &m)
106
{
107
  m.def("labelOverlapMeasures2D", &labelOverlapMeasures<unsigned int, 2>);
108
  m.def("labelOverlapMeasures3D", &labelOverlapMeasures<unsigned int, 3>);
109
  m.def("labelOverlapMeasures4D", &labelOverlapMeasures<unsigned int, 4>);
110
}