Switch to unified view

a b/src/simulateDisplacementField.cxx
1
2
#include <nanobind/nanobind.h>
3
#include <nanobind/stl/vector.h>
4
#include <nanobind/stl/string.h>
5
#include <nanobind/stl/tuple.h>
6
#include <nanobind/stl/list.h>
7
#include <nanobind/ndarray.h>
8
#include <nanobind/stl/shared_ptr.h>
9
10
#include <exception>
11
#include <vector>
12
#include <string>
13
14
#include "itkImage.h"
15
#include "itkCastImageFilter.h"
16
#include "antscore/itkSimulatedBSplineDisplacementFieldSource.h"
17
#include "antscore/itkSimulatedExponentialDisplacementFieldSource.h"
18
#include "antsImage.h"
19
20
namespace nb = nanobind;
21
using namespace nb::literals;
22
23
template<class PrecisionType, unsigned int Dimension>
24
AntsImage<itk::VectorImage<PrecisionType, Dimension>> simulateBsplineDisplacementField(AntsImage<itk::Image<PrecisionType, Dimension>> & antsDomainImage,
25
                                             unsigned int numberOfRandomPoints,
26
                                             float standardDeviationDisplacementField,
27
                                             bool enforceStationaryBoundary,
28
                                             unsigned int numberOfFittingLevels,
29
                                             std::vector<unsigned int> numberOfControlPoints)
30
{
31
  using ImageType = itk::Image<PrecisionType, Dimension>;
32
  using ImagePointerType = typename ImageType::Pointer;
33
34
  ImagePointerType domainImage = antsDomainImage.ptr;
35
36
  using VectorType = itk::Vector<PrecisionType, Dimension>;
37
  using DisplacementFieldType = itk::Image<VectorType, Dimension>;
38
  using ANTsFieldType = itk::VectorImage<PrecisionType, Dimension>;
39
  using IteratorType = itk::ImageRegionIteratorWithIndex<DisplacementFieldType>;
40
41
  using BSplineSimulatorType = itk::SimulatedBSplineDisplacementFieldSource<DisplacementFieldType>;
42
43
  typename BSplineSimulatorType::ArrayType ncps;
44
  for( unsigned int d = 0; d < numberOfControlPoints.size(); ++d )
45
    {
46
    ncps = numberOfControlPoints[d];
47
    }
48
49
  using RealImageType = typename BSplineSimulatorType::RealImageType;
50
  using CastImageFilterType = itk::CastImageFilter<ImageType, RealImageType>;
51
  typename CastImageFilterType::Pointer caster = CastImageFilterType::New();
52
  caster->SetInput( domainImage );
53
  caster->Update();
54
55
  typename BSplineSimulatorType::Pointer bsplineSimulator = BSplineSimulatorType::New();
56
  bsplineSimulator->SetDisplacementFieldDomainFromImage( caster->GetOutput() );
57
  bsplineSimulator->SetNumberOfRandomPoints( numberOfRandomPoints );
58
  bsplineSimulator->SetEnforceStationaryBoundary( enforceStationaryBoundary );
59
  bsplineSimulator->SetDisplacementNoiseStandardDeviation( standardDeviationDisplacementField );
60
  bsplineSimulator->SetNumberOfFittingLevels( numberOfFittingLevels );
61
  bsplineSimulator->SetNumberOfControlPoints( ncps );
62
  bsplineSimulator->Update();
63
64
  typename ANTsFieldType::Pointer antsField = ANTsFieldType::New();
65
  antsField->CopyInformation( domainImage );
66
  antsField->SetRegions( domainImage->GetRequestedRegion() );
67
  antsField->SetVectorLength( Dimension );
68
  antsField->AllocateInitialized();
69
70
  IteratorType It( bsplineSimulator->GetOutput(),
71
    bsplineSimulator->GetOutput()->GetRequestedRegion() );
72
  for( It.GoToBegin(); !It.IsAtEnd(); ++It )
73
    {
74
    VectorType itkVector = It.Value();
75
76
    typename ANTsFieldType::PixelType antsVector( Dimension );
77
    for( unsigned int d = 0; d < Dimension; d++ )
78
      {
79
      antsVector[d] = itkVector[d];
80
      }
81
    antsField->SetPixel( It.GetIndex(), antsVector );
82
    }
83
84
  AntsImage<ANTsFieldType> outImage = { antsField };
85
  return outImage;
86
}
87
88
template<class PrecisionType, unsigned int Dimension>
89
AntsImage<itk::VectorImage<PrecisionType, Dimension>> simulateExponentialDisplacementField(AntsImage<itk::Image<PrecisionType, Dimension>> & antsDomainImage,
90
                                                 unsigned int numberOfRandomPoints,
91
                                                 float standardDeviationDisplacementField,
92
                                                 bool enforceStationaryBoundary,
93
                                                 float standardDeviationSmoothing)
94
{
95
  using ImageType = itk::Image<PrecisionType, Dimension>;
96
  using ImagePointerType = typename ImageType::Pointer;
97
98
  ImagePointerType domainImage = antsDomainImage.ptr;
99
100
  using VectorType = itk::Vector<PrecisionType, Dimension>;
101
  using DisplacementFieldType = itk::Image<VectorType, Dimension>;
102
  using ANTsFieldType = itk::VectorImage<PrecisionType, Dimension>;
103
  using IteratorType = itk::ImageRegionIteratorWithIndex<DisplacementFieldType>;
104
105
  using ExponentialSimulatorType = itk::SimulatedExponentialDisplacementFieldSource<DisplacementFieldType>;
106
107
  using RealImageType = typename ExponentialSimulatorType::RealImageType;
108
  using CastImageFilterType = itk::CastImageFilter<ImageType, RealImageType>;
109
  typename CastImageFilterType::Pointer caster = CastImageFilterType::New();
110
  caster->SetInput( domainImage );
111
  caster->Update();
112
113
  typename ExponentialSimulatorType::Pointer exponentialSimulator = ExponentialSimulatorType::New();
114
  exponentialSimulator->SetDisplacementFieldDomainFromImage( caster->GetOutput() );
115
  exponentialSimulator->SetNumberOfRandomPoints( numberOfRandomPoints );
116
  exponentialSimulator->SetEnforceStationaryBoundary( enforceStationaryBoundary );
117
  exponentialSimulator->SetDisplacementNoiseStandardDeviation( standardDeviationDisplacementField );
118
  exponentialSimulator->SetSmoothingStandardDeviation( standardDeviationSmoothing );
119
  exponentialSimulator->Update();
120
121
  typename ANTsFieldType::Pointer antsField = ANTsFieldType::New();
122
  antsField->CopyInformation( domainImage );
123
  antsField->SetRegions( domainImage->GetRequestedRegion() );
124
  antsField->SetVectorLength( Dimension );
125
  antsField->AllocateInitialized();
126
127
  IteratorType It( exponentialSimulator->GetOutput(),
128
    exponentialSimulator->GetOutput()->GetRequestedRegion() );
129
  for( It.GoToBegin(); !It.IsAtEnd(); ++It )
130
    {
131
    VectorType itkVector = It.Value();
132
133
    typename ANTsFieldType::PixelType antsVector( Dimension );
134
    for( unsigned int d = 0; d < Dimension; d++ )
135
      {
136
      antsVector[d] = itkVector[d];
137
      }
138
    antsField->SetPixel( It.GetIndex(), antsVector );
139
    }
140
141
  AntsImage<ANTsFieldType> outImage = { antsField };
142
  return outImage;
143
}
144
145
void local_simulateDisplacementField(nb::module_ &m)
146
{
147
  m.def("simulateBsplineDisplacementField2D", &simulateBsplineDisplacementField<float, 2>);
148
  m.def("simulateBsplineDisplacementField3D", &simulateBsplineDisplacementField<float, 3>);
149
150
  m.def("simulateExponentialDisplacementField2D", &simulateExponentialDisplacementField<float, 2>);
151
  m.def("simulateExponentialDisplacementField3D", &simulateExponentialDisplacementField<float, 3>);
152
}
153