Diff of /src/labelStats.cxx [000000] .. [5d12a0]

Switch to side-by-side view

--- a
+++ b/src/labelStats.cxx
@@ -0,0 +1,158 @@
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/vector.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/list.h>
+#include <nanobind/ndarray.h>
+#include <nanobind/stl/shared_ptr.h>
+
+#include "itkImage.h"
+#include "itkLabelStatisticsImageFilter.h"
+#include "itkImageRegionIteratorWithIndex.h"
+
+#include "antsImage.h"
+
+namespace nb = nanobind;
+using namespace nb::literals;
+
+template< unsigned int Dimension >
+nb::dict labelStatsHelper(
+  typename itk::Image< float, Dimension >::Pointer image,
+  typename itk::Image< unsigned int, Dimension>::Pointer labelImage)
+{
+  typedef float PixelType;
+  typedef itk::Image< PixelType, Dimension > ImageType;
+  typedef unsigned int LabelType;
+  typedef typename ImageType::PointType PointType;
+  typedef itk::Image< LabelType, Dimension > LabelImageType;
+  typedef itk::ImageRegionIteratorWithIndex<LabelImageType>                    Iterator;
+  typedef itk::LabelStatisticsImageFilter< ImageType, LabelImageType >
+    LabelStatisticsImageFilterType;
+  typename LabelStatisticsImageFilterType::Pointer labelStatisticsImageFilter =
+    LabelStatisticsImageFilterType::New();
+  labelStatisticsImageFilter->SetInput( image );
+  labelStatisticsImageFilter->SetLabelInput( labelImage );
+  labelStatisticsImageFilter->Update();
+
+  typedef typename LabelStatisticsImageFilterType::ValidLabelValuesContainerType
+    ValidLabelValuesType;
+
+  long nlabs = labelStatisticsImageFilter->GetNumberOfLabels();
+
+  std::vector<double> labelvals(nlabs);
+  std::vector<double> means(nlabs);
+  std::vector<double> mins(nlabs);
+  std::vector<double> maxes(nlabs);
+  std::vector<double> variances(nlabs);
+  std::vector<double> counts(nlabs);
+  std::vector<double> volumes(nlabs);
+  std::vector<double> x(nlabs);
+  std::vector<double> y(nlabs);
+  std::vector<double> z(nlabs);
+  std::vector<double> t(nlabs);
+  std::vector<double> mass(nlabs,0.0);
+
+  typename ImageType::SpacingType spacing = image->GetSpacing();
+  float voxelVolume = 1.0;
+  for (unsigned int ii = 0; ii < spacing.Size(); ii++)
+  {
+    voxelVolume *= spacing[ii];
+  }
+
+  std::map<LabelType, LabelType> RoiList;
+
+  LabelType ii = 0; // counter for label values
+  for (typename ValidLabelValuesType::const_iterator
+         labelIterator  = labelStatisticsImageFilter->GetValidLabelValues().begin();
+         labelIterator != labelStatisticsImageFilter->GetValidLabelValues().end();
+         ++labelIterator)
+  {
+    if ( labelStatisticsImageFilter->HasLabel(*labelIterator) )
+    {
+      LabelType labelValue = *labelIterator;
+      labelvals[ii] = labelValue;
+      means[ii]     = labelStatisticsImageFilter->GetMean(labelValue);
+      mins[ii]      = labelStatisticsImageFilter->GetMinimum(labelValue);
+      maxes[ii]     = labelStatisticsImageFilter->GetMaximum(labelValue);
+      variances[ii] = labelStatisticsImageFilter->GetVariance(labelValue);
+      counts[ii]    = labelStatisticsImageFilter->GetCount(labelValue);
+      volumes[ii]   = labelStatisticsImageFilter->GetCount(labelValue) * voxelVolume;
+      RoiList[ labelValue ] = ii;
+    }
+    ++ii;
+  }
+
+  Iterator It( labelImage, labelImage->GetLargestPossibleRegion() );
+  std::vector<PointType> comvec;
+  for ( unsigned int i = 0; i < nlabs; i++ )
+    {
+    typename ImageType::PointType myCenterOfMass;
+    myCenterOfMass.Fill(0);
+    comvec.push_back( myCenterOfMass );
+    }
+  for( It.GoToBegin(); !It.IsAtEnd(); ++It )
+    {
+    LabelType label = static_cast<LabelType>( It.Get() );
+    if(  label > 0  )
+      {
+      typename ImageType::PointType point;
+      image->TransformIndexToPhysicalPoint( It.GetIndex(), point );
+      for( unsigned int i = 0; i < spacing.Size(); i++ )
+        {
+        comvec[  RoiList[ label ] ][i] += point[i];
+        }
+      mass[  RoiList[ label ] ] += image->GetPixel( It.GetIndex() );
+      }
+    }
+  for ( unsigned int labelcount = 0; labelcount < comvec.size(); labelcount++ )
+    {
+    for ( unsigned int k = 0; k < Dimension; k++ )
+      {
+      comvec[ labelcount ][k] = comvec[ labelcount ][k] / counts[labelcount];
+      }
+    x[labelcount]=comvec[ labelcount ][0];
+    y[labelcount]=comvec[ labelcount ][1];
+    if ( Dimension > 2 ) z[labelcount]=comvec[ labelcount ][2];
+    if ( Dimension > 3 ) t[labelcount]=comvec[ labelcount ][3];
+    }
+
+  nb::dict labelStats;
+  labelStats["LabelValue"] = labelvals;
+  labelStats["Mean"] = means;
+  labelStats["Min"] = mins;
+  labelStats["Max"] = maxes;
+  labelStats["Variance"] = variances;
+  labelStats["Count"] = counts;
+  labelStats["Volume"] = volumes;
+  labelStats["Mass"] = mass;
+  labelStats["x"] = x;
+  labelStats["y"] = y;
+  labelStats["z"] = z;
+  labelStats["t"] = t;
+
+  return (labelStats);
+}
+
+template <unsigned int Dimension>
+nb::dict labelStats(AntsImage<itk::Image<float, Dimension>> & py_image,
+                    AntsImage<itk::Image<unsigned int, Dimension>> & py_labelImage)
+{ 
+  typedef itk::Image<float, Dimension> FloatImageType;
+  typedef itk::Image<unsigned int, Dimension> IntImageType;
+  typedef typename FloatImageType::Pointer FloatImagePointerType;
+  typedef typename IntImageType::Pointer IntImagePointerType;
+
+
+  FloatImagePointerType myimage = py_image.ptr;
+  IntImagePointerType mylabelimage = py_labelImage.ptr;
+
+  return labelStatsHelper<Dimension>( myimage, mylabelimage );
+}
+
+void local_labelStats(nb::module_ &m)
+{
+  m.def("labelStats2D", &labelStats<2>);
+  m.def("labelStats3D", &labelStats<3>);
+  m.def("labelStats4D", &labelStats<4>);
+}