--- a +++ b/src/simulateDisplacementField.cxx @@ -0,0 +1,153 @@ + +#include <nanobind/nanobind.h> +#include <nanobind/stl/vector.h> +#include <nanobind/stl/string.h> +#include <nanobind/stl/tuple.h> +#include <nanobind/stl/list.h> +#include <nanobind/ndarray.h> +#include <nanobind/stl/shared_ptr.h> + +#include <exception> +#include <vector> +#include <string> + +#include "itkImage.h" +#include "itkCastImageFilter.h" +#include "antscore/itkSimulatedBSplineDisplacementFieldSource.h" +#include "antscore/itkSimulatedExponentialDisplacementFieldSource.h" +#include "antsImage.h" + +namespace nb = nanobind; +using namespace nb::literals; + +template<class PrecisionType, unsigned int Dimension> +AntsImage<itk::VectorImage<PrecisionType, Dimension>> simulateBsplineDisplacementField(AntsImage<itk::Image<PrecisionType, Dimension>> & antsDomainImage, + unsigned int numberOfRandomPoints, + float standardDeviationDisplacementField, + bool enforceStationaryBoundary, + unsigned int numberOfFittingLevels, + std::vector<unsigned int> numberOfControlPoints) +{ + using ImageType = itk::Image<PrecisionType, Dimension>; + using ImagePointerType = typename ImageType::Pointer; + + ImagePointerType domainImage = antsDomainImage.ptr; + + using VectorType = itk::Vector<PrecisionType, Dimension>; + using DisplacementFieldType = itk::Image<VectorType, Dimension>; + using ANTsFieldType = itk::VectorImage<PrecisionType, Dimension>; + using IteratorType = itk::ImageRegionIteratorWithIndex<DisplacementFieldType>; + + using BSplineSimulatorType = itk::SimulatedBSplineDisplacementFieldSource<DisplacementFieldType>; + + typename BSplineSimulatorType::ArrayType ncps; + for( unsigned int d = 0; d < numberOfControlPoints.size(); ++d ) + { + ncps = numberOfControlPoints[d]; + } + + using RealImageType = typename BSplineSimulatorType::RealImageType; + using CastImageFilterType = itk::CastImageFilter<ImageType, RealImageType>; + typename CastImageFilterType::Pointer caster = CastImageFilterType::New(); + caster->SetInput( domainImage ); + caster->Update(); + + typename BSplineSimulatorType::Pointer bsplineSimulator = BSplineSimulatorType::New(); + bsplineSimulator->SetDisplacementFieldDomainFromImage( caster->GetOutput() ); + bsplineSimulator->SetNumberOfRandomPoints( numberOfRandomPoints ); + bsplineSimulator->SetEnforceStationaryBoundary( enforceStationaryBoundary ); + bsplineSimulator->SetDisplacementNoiseStandardDeviation( standardDeviationDisplacementField ); + bsplineSimulator->SetNumberOfFittingLevels( numberOfFittingLevels ); + bsplineSimulator->SetNumberOfControlPoints( ncps ); + bsplineSimulator->Update(); + + typename ANTsFieldType::Pointer antsField = ANTsFieldType::New(); + antsField->CopyInformation( domainImage ); + antsField->SetRegions( domainImage->GetRequestedRegion() ); + antsField->SetVectorLength( Dimension ); + antsField->AllocateInitialized(); + + IteratorType It( bsplineSimulator->GetOutput(), + bsplineSimulator->GetOutput()->GetRequestedRegion() ); + for( It.GoToBegin(); !It.IsAtEnd(); ++It ) + { + VectorType itkVector = It.Value(); + + typename ANTsFieldType::PixelType antsVector( Dimension ); + for( unsigned int d = 0; d < Dimension; d++ ) + { + antsVector[d] = itkVector[d]; + } + antsField->SetPixel( It.GetIndex(), antsVector ); + } + + AntsImage<ANTsFieldType> outImage = { antsField }; + return outImage; +} + +template<class PrecisionType, unsigned int Dimension> +AntsImage<itk::VectorImage<PrecisionType, Dimension>> simulateExponentialDisplacementField(AntsImage<itk::Image<PrecisionType, Dimension>> & antsDomainImage, + unsigned int numberOfRandomPoints, + float standardDeviationDisplacementField, + bool enforceStationaryBoundary, + float standardDeviationSmoothing) +{ + using ImageType = itk::Image<PrecisionType, Dimension>; + using ImagePointerType = typename ImageType::Pointer; + + ImagePointerType domainImage = antsDomainImage.ptr; + + using VectorType = itk::Vector<PrecisionType, Dimension>; + using DisplacementFieldType = itk::Image<VectorType, Dimension>; + using ANTsFieldType = itk::VectorImage<PrecisionType, Dimension>; + using IteratorType = itk::ImageRegionIteratorWithIndex<DisplacementFieldType>; + + using ExponentialSimulatorType = itk::SimulatedExponentialDisplacementFieldSource<DisplacementFieldType>; + + using RealImageType = typename ExponentialSimulatorType::RealImageType; + using CastImageFilterType = itk::CastImageFilter<ImageType, RealImageType>; + typename CastImageFilterType::Pointer caster = CastImageFilterType::New(); + caster->SetInput( domainImage ); + caster->Update(); + + typename ExponentialSimulatorType::Pointer exponentialSimulator = ExponentialSimulatorType::New(); + exponentialSimulator->SetDisplacementFieldDomainFromImage( caster->GetOutput() ); + exponentialSimulator->SetNumberOfRandomPoints( numberOfRandomPoints ); + exponentialSimulator->SetEnforceStationaryBoundary( enforceStationaryBoundary ); + exponentialSimulator->SetDisplacementNoiseStandardDeviation( standardDeviationDisplacementField ); + exponentialSimulator->SetSmoothingStandardDeviation( standardDeviationSmoothing ); + exponentialSimulator->Update(); + + typename ANTsFieldType::Pointer antsField = ANTsFieldType::New(); + antsField->CopyInformation( domainImage ); + antsField->SetRegions( domainImage->GetRequestedRegion() ); + antsField->SetVectorLength( Dimension ); + antsField->AllocateInitialized(); + + IteratorType It( exponentialSimulator->GetOutput(), + exponentialSimulator->GetOutput()->GetRequestedRegion() ); + for( It.GoToBegin(); !It.IsAtEnd(); ++It ) + { + VectorType itkVector = It.Value(); + + typename ANTsFieldType::PixelType antsVector( Dimension ); + for( unsigned int d = 0; d < Dimension; d++ ) + { + antsVector[d] = itkVector[d]; + } + antsField->SetPixel( It.GetIndex(), antsVector ); + } + + AntsImage<ANTsFieldType> outImage = { antsField }; + return outImage; +} + +void local_simulateDisplacementField(nb::module_ &m) +{ + m.def("simulateBsplineDisplacementField2D", &simulateBsplineDisplacementField<float, 2>); + m.def("simulateBsplineDisplacementField3D", &simulateBsplineDisplacementField<float, 3>); + + m.def("simulateExponentialDisplacementField2D", &simulateExponentialDisplacementField<float, 2>); + m.def("simulateExponentialDisplacementField3D", &simulateExponentialDisplacementField<float, 3>); +} +