--- a +++ b/src/labelStats.cxx @@ -0,0 +1,158 @@ + +#include <nanobind/nanobind.h> +#include <nanobind/stl/vector.h> +#include <nanobind/stl/string.h> +#include <nanobind/stl/tuple.h> +#include <nanobind/stl/list.h> +#include <nanobind/ndarray.h> +#include <nanobind/stl/shared_ptr.h> + +#include "itkImage.h" +#include "itkLabelStatisticsImageFilter.h" +#include "itkImageRegionIteratorWithIndex.h" + +#include "antsImage.h" + +namespace nb = nanobind; +using namespace nb::literals; + +template< unsigned int Dimension > +nb::dict labelStatsHelper( + typename itk::Image< float, Dimension >::Pointer image, + typename itk::Image< unsigned int, Dimension>::Pointer labelImage) +{ + typedef float PixelType; + typedef itk::Image< PixelType, Dimension > ImageType; + typedef unsigned int LabelType; + typedef typename ImageType::PointType PointType; + typedef itk::Image< LabelType, Dimension > LabelImageType; + typedef itk::ImageRegionIteratorWithIndex<LabelImageType> Iterator; + typedef itk::LabelStatisticsImageFilter< ImageType, LabelImageType > + LabelStatisticsImageFilterType; + typename LabelStatisticsImageFilterType::Pointer labelStatisticsImageFilter = + LabelStatisticsImageFilterType::New(); + labelStatisticsImageFilter->SetInput( image ); + labelStatisticsImageFilter->SetLabelInput( labelImage ); + labelStatisticsImageFilter->Update(); + + typedef typename LabelStatisticsImageFilterType::ValidLabelValuesContainerType + ValidLabelValuesType; + + long nlabs = labelStatisticsImageFilter->GetNumberOfLabels(); + + std::vector<double> labelvals(nlabs); + std::vector<double> means(nlabs); + std::vector<double> mins(nlabs); + std::vector<double> maxes(nlabs); + std::vector<double> variances(nlabs); + std::vector<double> counts(nlabs); + std::vector<double> volumes(nlabs); + std::vector<double> x(nlabs); + std::vector<double> y(nlabs); + std::vector<double> z(nlabs); + std::vector<double> t(nlabs); + std::vector<double> mass(nlabs,0.0); + + typename ImageType::SpacingType spacing = image->GetSpacing(); + float voxelVolume = 1.0; + for (unsigned int ii = 0; ii < spacing.Size(); ii++) + { + voxelVolume *= spacing[ii]; + } + + std::map<LabelType, LabelType> RoiList; + + LabelType ii = 0; // counter for label values + for (typename ValidLabelValuesType::const_iterator + labelIterator = labelStatisticsImageFilter->GetValidLabelValues().begin(); + labelIterator != labelStatisticsImageFilter->GetValidLabelValues().end(); + ++labelIterator) + { + if ( labelStatisticsImageFilter->HasLabel(*labelIterator) ) + { + LabelType labelValue = *labelIterator; + labelvals[ii] = labelValue; + means[ii] = labelStatisticsImageFilter->GetMean(labelValue); + mins[ii] = labelStatisticsImageFilter->GetMinimum(labelValue); + maxes[ii] = labelStatisticsImageFilter->GetMaximum(labelValue); + variances[ii] = labelStatisticsImageFilter->GetVariance(labelValue); + counts[ii] = labelStatisticsImageFilter->GetCount(labelValue); + volumes[ii] = labelStatisticsImageFilter->GetCount(labelValue) * voxelVolume; + RoiList[ labelValue ] = ii; + } + ++ii; + } + + Iterator It( labelImage, labelImage->GetLargestPossibleRegion() ); + std::vector<PointType> comvec; + for ( unsigned int i = 0; i < nlabs; i++ ) + { + typename ImageType::PointType myCenterOfMass; + myCenterOfMass.Fill(0); + comvec.push_back( myCenterOfMass ); + } + for( It.GoToBegin(); !It.IsAtEnd(); ++It ) + { + LabelType label = static_cast<LabelType>( It.Get() ); + if( label > 0 ) + { + typename ImageType::PointType point; + image->TransformIndexToPhysicalPoint( It.GetIndex(), point ); + for( unsigned int i = 0; i < spacing.Size(); i++ ) + { + comvec[ RoiList[ label ] ][i] += point[i]; + } + mass[ RoiList[ label ] ] += image->GetPixel( It.GetIndex() ); + } + } + for ( unsigned int labelcount = 0; labelcount < comvec.size(); labelcount++ ) + { + for ( unsigned int k = 0; k < Dimension; k++ ) + { + comvec[ labelcount ][k] = comvec[ labelcount ][k] / counts[labelcount]; + } + x[labelcount]=comvec[ labelcount ][0]; + y[labelcount]=comvec[ labelcount ][1]; + if ( Dimension > 2 ) z[labelcount]=comvec[ labelcount ][2]; + if ( Dimension > 3 ) t[labelcount]=comvec[ labelcount ][3]; + } + + nb::dict labelStats; + labelStats["LabelValue"] = labelvals; + labelStats["Mean"] = means; + labelStats["Min"] = mins; + labelStats["Max"] = maxes; + labelStats["Variance"] = variances; + labelStats["Count"] = counts; + labelStats["Volume"] = volumes; + labelStats["Mass"] = mass; + labelStats["x"] = x; + labelStats["y"] = y; + labelStats["z"] = z; + labelStats["t"] = t; + + return (labelStats); +} + +template <unsigned int Dimension> +nb::dict labelStats(AntsImage<itk::Image<float, Dimension>> & py_image, + AntsImage<itk::Image<unsigned int, Dimension>> & py_labelImage) +{ + typedef itk::Image<float, Dimension> FloatImageType; + typedef itk::Image<unsigned int, Dimension> IntImageType; + typedef typename FloatImageType::Pointer FloatImagePointerType; + typedef typename IntImageType::Pointer IntImagePointerType; + + + FloatImagePointerType myimage = py_image.ptr; + IntImagePointerType mylabelimage = py_labelImage.ptr; + + return labelStatsHelper<Dimension>( myimage, mylabelimage ); +} + +void local_labelStats(nb::module_ &m) +{ + m.def("labelStats2D", &labelStats<2>); + m.def("labelStats3D", &labelStats<3>); + m.def("labelStats4D", &labelStats<4>); +}