Diff of /src/sccaner.cxx [000000] .. [5d12a0]

Switch to unified view

a b/src/sccaner.cxx
1
2
#include <nanobind/nanobind.h>
3
#include <nanobind/stl/vector.h>
4
#include <nanobind/stl/string.h>
5
#include <nanobind/stl/tuple.h>
6
#include <nanobind/stl/list.h>
7
#include <nanobind/ndarray.h>
8
#include <nanobind/stl/shared_ptr.h>
9
10
#include <exception>
11
#include <vector>
12
#include <string>
13
#include "itkImage.h"
14
#include "itkVectorImage.h"
15
#include "itkImageRegionIteratorWithIndex.h"
16
17
#include "antscore/antsSCCANObject.h"
18
#include "antsImage.h"
19
20
namespace nb = nanobind;
21
using namespace nb::literals;
22
23
template< class ImageType, class IntType, class RealType >
24
nb::dict sccanCppHelper(
25
  std::vector<std::vector<double>> X,
26
  std::vector<std::vector<double>> Y,
27
  AntsImage<ImageType> & maskXimage,
28
  AntsImage<ImageType> & maskYimage,
29
  int maskxisnull,
30
  int maskyisnull,
31
  RealType sparsenessx,
32
  RealType sparsenessy,
33
  IntType nvecs,
34
  IntType its,
35
  IntType cthreshx,
36
  IntType cthreshy,
37
  RealType z,
38
  RealType smooth,
39
  std::vector<AntsImage<ImageType>> initializationListx,
40
  std::vector<AntsImage<ImageType>> initializationListy,
41
  IntType covering,
42
  RealType ell1,
43
  IntType verbose,
44
  RealType priorWeight,
45
  IntType useMaxBasedThresh )
46
{
47
  enum { Dimension = ImageType::ImageDimension };
48
  typename ImageType::RegionType region;
49
  typedef typename ImageType::PixelType PixelType;
50
  typedef typename ImageType::Pointer ImagePointerType;
51
  typedef double                                        Scalar;
52
  typedef itk::ants::antsSCCANObject<ImageType, Scalar> SCCANType;
53
  typedef typename SCCANType::MatrixType                vMatrix;
54
  typename SCCANType::Pointer sccanobj = SCCANType::New();
55
  sccanobj->SetMaxBasedThresholding( useMaxBasedThresh );
56
57
  // cast mask ANTsImages to itk
58
  typename ImageType::Pointer maskx = ITK_NULLPTR;
59
  if (maskxisnull > 0)
60
  {
61
    maskx = maskXimage.ptr;
62
  }
63
  typename ImageType::Pointer masky = ITK_NULLPTR;
64
  if (maskyisnull > 0)
65
  {
66
    masky = maskYimage.ptr;
67
  }
68
69
// deal with the initializationList, if any
70
  unsigned int nImagesx = initializationListx.size();
71
  if ( ( nImagesx > 0 ) && ( !maskxisnull ) )
72
  {
73
    itk::ImageRegionIteratorWithIndex<ImageType> it( maskx,
74
      maskx->GetLargestPossibleRegion() );
75
    vMatrix priorROIMatx( nImagesx , X[0].size() );
76
    priorROIMatx.fill( 0 );
77
    for ( unsigned int i = 0; i < nImagesx; i++ )
78
    {
79
      typename ImageType::Pointer init = initializationListx[i].ptr;
80
      unsigned long ct = 0;
81
      it.GoToBegin();
82
      while ( !it.IsAtEnd() )
83
      {
84
        PixelType pix = it.Get();
85
        if ( pix >= 0.5 )
86
        {
87
          pix = init->GetPixel( it.GetIndex() );
88
          priorROIMatx( i, ct ) = pix;
89
          ct++;
90
        }
91
        ++it;
92
      }
93
    }
94
    sccanobj->SetMatrixPriorROI( priorROIMatx );
95
    nvecs = nImagesx;
96
  }
97
  unsigned int nImagesy = initializationListy.size();
98
  if ( ( nImagesy > 0 ) && ( !maskyisnull ) )
99
  {
100
    itk::ImageRegionIteratorWithIndex<ImageType> it( masky,
101
      masky->GetLargestPossibleRegion() );
102
    vMatrix priorROIMaty( nImagesy , Y[0].size() );
103
    priorROIMaty.fill( 0 );
104
    for ( unsigned int i = 0; i < nImagesy; i++ )
105
    {
106
      typename ImageType::Pointer init = initializationListy[i].ptr;
107
      unsigned long ct = 0;
108
      it.GoToBegin();
109
      while ( !it.IsAtEnd() )
110
      {
111
        PixelType pix = it.Get();
112
        if ( pix >= 0.5 )
113
        {
114
          pix = init->GetPixel( it.GetIndex() );
115
          priorROIMaty( i, ct ) = pix;
116
          ct++;
117
        }
118
        ++it;
119
      }
120
    }
121
    sccanobj->SetMatrixPriorROI2( priorROIMaty );
122
    nvecs = nImagesy;
123
  }
124
  sccanobj->SetPriorWeight( priorWeight );
125
  sccanobj->SetLambda( priorWeight );
126
// cast hack from Python type to sccan type
127
  //std::vector<double> xdat = X;//.reshape({-1}).cast<std::vector<double> >();
128
  //const double* _xdata = &xdat[0];
129
130
    vMatrix vnlX(  X.size(), X[0].size() );
131
  for (int i = 0; i < X.size(); i++)
132
  {
133
    for (int j = 0; j < X[0].size(); j++)
134
    {
135
      vnlX(i,j) = X[i][j];
136
    }
137
  }
138
139
  //vnlX = vnlX.transpose();
140
141
  //std::vector<double> ydat = Y.reshape({-1}).cast<std::vector<double> >();
142
  //const double* _ydata = &ydat[0];
143
  //vMatrix vnlY( _ydata , Y.shape(0), Y.shape(1)  );
144
    vMatrix vnlY(  Y.size(), Y[0].size() );
145
  for (int i = 0; i < Y.size(); i++)
146
  {
147
    for (int j = 0; j < Y[0].size(); j++)
148
    {
149
      vnlY(i,j) = Y[i][j];
150
    }
151
  }
152
  //vnlY = vnlY.transpose();
153
// cast hack done
154
  sccanobj->SetGetSmall( false  );
155
  sccanobj->SetCovering( covering );
156
  sccanobj->SetSilent(  ! verbose  );
157
  if( ell1 > 0 )
158
    {
159
    sccanobj->SetUseL1( true );
160
    }
161
  else
162
    {
163
    sccanobj->SetUseL1( false );
164
    }
165
  sccanobj->SetGradStep( std::abs( ell1 ) );
166
  sccanobj->SetMaximumNumberOfIterations( its );
167
  sccanobj->SetRowSparseness( z );
168
  sccanobj->SetSmoother( smooth );
169
  if ( sparsenessx < 0 ) sccanobj->SetKeepPositiveP(false);
170
  if ( sparsenessy < 0 ) sccanobj->SetKeepPositiveQ(false);
171
  sccanobj->SetSCCANFormulation(  SCCANType::PQ );
172
  sccanobj->SetFractionNonZeroP( fabs( sparsenessx ) );
173
  sccanobj->SetFractionNonZeroQ( fabs( sparsenessy ) );
174
  sccanobj->SetMinClusterSizeP( cthreshx );
175
  sccanobj->SetMinClusterSizeQ( cthreshy );
176
  sccanobj->SetMatrixP( vnlX );
177
  sccanobj->SetMatrixQ( vnlY );
178
//  sccanobj->SetMatrixR( r ); // FIXME
179
  sccanobj->SetMaskImageP( maskx );
180
  sccanobj->SetMaskImageQ( masky );
181
  sccanobj->SparsePartialArnoldiCCA( nvecs );
182
183
  // FIXME - should not copy, should map memory
184
  vMatrix solP = sccanobj->GetVariatesP();
185
  std::vector<std::vector<float> > eanatMatp( solP.cols(), std::vector<float>(solP.rows()));
186
  unsigned long rows = solP.rows();
187
  for( unsigned long c = 0; c < solP.cols(); c++ )
188
    {
189
    for( unsigned int r = 0; r < rows; r++ )
190
      {
191
      eanatMatp[c][r] = solP( r, c );
192
      }
193
    }
194
195
  vMatrix solQ = sccanobj->GetVariatesQ();
196
197
  std::vector<std::vector<float> > eanatMatq( solQ.cols(), std::vector<float>(solQ.rows()));
198
  rows = solQ.rows();
199
  for( unsigned long c = 0; c < solQ.cols(); c++ )
200
    {
201
    for( unsigned int r = 0; r < rows; r++ )
202
      {
203
      eanatMatq[c][r] = solQ( r, c );
204
      }
205
    }
206
207
  //nb::ndarray<double> eanatMatpList = nb::cast(eanatMatp);
208
  //nb::ndarray<double> eanatMatqList = nb::cast(eanatMatq);
209
210
     nb::dict res;
211
   res["eig1"] = eanatMatp; 
212
   res["eig2"] = eanatMatq;
213
   return res;
214
215
}
216
217
template <typename ImageType>
218
nb::dict sccanCpp(std::vector<std::vector<double>> X,
219
                   std::vector<std::vector<double>> Y,
220
                   AntsImage<ImageType> & maskXimage,
221
                   AntsImage<ImageType> & maskYimage,
222
                   int maskxisnull,
223
                   int maskyisnull,
224
                   float sparsenessx,
225
                   float sparsenessy,
226
                   unsigned int nvecs,
227
                   unsigned int its,
228
                   unsigned int cthreshx,
229
                   unsigned int cthreshy,
230
                   float z,
231
                   float smooth,
232
                   std::vector<AntsImage<ImageType>> initializationListx,
233
                   std::vector<AntsImage<ImageType>> initializationListy,
234
                   float ell1,
235
                   unsigned int verbose,
236
                   float priorWeight,
237
                   unsigned int mycoption,
238
                   unsigned int maxBasedThresh)
239
{
240
  typedef float RealType;
241
  typedef unsigned int IntType;
242
  return  sccanCppHelper<ImageType,IntType,RealType>(
243
        X,
244
        Y,
245
        maskXimage,
246
        maskYimage,
247
        maskxisnull,
248
        maskyisnull,
249
        sparsenessx,
250
        sparsenessy,
251
        nvecs,
252
        its,
253
        cthreshx,
254
        cthreshy,
255
        z,
256
        smooth,
257
        initializationListx,
258
        initializationListy,
259
        mycoption,
260
        ell1,
261
        verbose,
262
        priorWeight,
263
        maxBasedThresh
264
        );
265
266
}
267
268
269
template< class ImageType, class IntType, class RealType >
270
nb::dict sccanCppHelperV2(
271
  std::vector<std::vector<double> > X,
272
  std::vector<std::vector<double> > Y,
273
  AntsImage<ImageType> & maskXimage,
274
  AntsImage<ImageType> & maskYimage,
275
  int maskxisnull,
276
  int maskyisnull,
277
  RealType sparsenessx,
278
  RealType sparsenessy,
279
  IntType nvecs,
280
  IntType its,
281
  IntType cthreshx,
282
  IntType cthreshy,
283
  RealType z,
284
  RealType smooth,
285
  std::vector<AntsImage<ImageType>> initializationListx,
286
  std::vector<AntsImage<ImageType>> initializationListy,
287
  IntType covering,
288
  RealType ell1,
289
  IntType verbose,
290
  RealType priorWeight,
291
  IntType useMaxBasedThresh )
292
{
293
  enum { Dimension = ImageType::ImageDimension };
294
  typename ImageType::RegionType region;
295
  typedef typename ImageType::PixelType PixelType;
296
  typedef typename ImageType::Pointer ImagePointerType;
297
  typedef double                                        Scalar;
298
  typedef itk::ants::antsSCCANObject<ImageType, Scalar> SCCANType;
299
  typedef typename SCCANType::MatrixType                vMatrix;
300
  typename SCCANType::Pointer sccanobj = SCCANType::New();
301
  sccanobj->SetMaxBasedThresholding( useMaxBasedThresh );
302
303
  // cast mask ANTsImages to itk
304
  typename ImageType::Pointer maskx = ITK_NULLPTR;
305
  if (maskxisnull > 0)
306
  {
307
    maskx = maskXimage.ptr;
308
  }
309
  typename ImageType::Pointer masky = ITK_NULLPTR;
310
  if (maskyisnull > 0)
311
  {
312
    masky = maskYimage.ptr;
313
  }
314
315
// deal with the initializationList, if any
316
  unsigned int nImagesx = initializationListx.size();
317
  if ( ( nImagesx > 0 ) && ( !maskxisnull ) )
318
  {
319
    itk::ImageRegionIteratorWithIndex<ImageType> it( maskx,
320
      maskx->GetLargestPossibleRegion() );
321
    vMatrix priorROIMatx( nImagesx , X[0].size() );
322
    priorROIMatx.fill( 0 );
323
    for ( unsigned int i = 0; i < nImagesx; i++ )
324
    {
325
      typename ImageType::Pointer init = initializationListx[i].ptr;
326
      unsigned long ct = 0;
327
      it.GoToBegin();
328
      while ( !it.IsAtEnd() )
329
      {
330
        PixelType pix = it.Get();
331
        if ( pix >= 0.5 )
332
        {
333
          pix = init->GetPixel( it.GetIndex() );
334
          priorROIMatx( i, ct ) = pix;
335
          ct++;
336
        }
337
        ++it;
338
      }
339
    }
340
    sccanobj->SetMatrixPriorROI( priorROIMatx );
341
    nvecs = nImagesx;
342
  }
343
  unsigned int nImagesy = initializationListy.size();
344
  if ( ( nImagesy > 0 ) && ( !maskyisnull ) )
345
  {
346
    itk::ImageRegionIteratorWithIndex<ImageType> it( masky,
347
      masky->GetLargestPossibleRegion() );
348
    vMatrix priorROIMaty( nImagesy , Y[0].size() );
349
    priorROIMaty.fill( 0 );
350
    for ( unsigned int i = 0; i < nImagesy; i++ )
351
    {
352
      typename ImageType::Pointer init = initializationListy[i].ptr;
353
      unsigned long ct = 0;
354
      it.GoToBegin();
355
      while ( !it.IsAtEnd() )
356
      {
357
        PixelType pix = it.Get();
358
        if ( pix >= 0.5 )
359
        {
360
          pix = init->GetPixel( it.GetIndex() );
361
          priorROIMaty( i, ct ) = pix;
362
          ct++;
363
        }
364
        ++it;
365
      }
366
    }
367
    sccanobj->SetMatrixPriorROI2( priorROIMaty );
368
    nvecs = nImagesy;
369
  }
370
  sccanobj->SetPriorWeight( priorWeight );
371
  sccanobj->SetLambda( priorWeight );
372
// cast hack from Python type to sccan type
373
  //std::vector<std::vector<double> > xdat;
374
  //const double* _xdata = &X[0][0];
375
  vMatrix vnlX( X.size(), X[0].size()  );
376
  for (int i = 0; i < X.size(); i++)
377
  {
378
    for (int j = 0; j < X[0].size(); j++)
379
    {
380
      vnlX(i,j) = X[i][j];
381
    }
382
  }
383
  //vnlX = vnlX.transpose();
384
385
  //std::vector<std::vector<double> > ydat;
386
  //const double*  _ydata = &Y[0][0];
387
  vMatrix vnlY(  Y.size(), Y[0].size() );
388
  for (int i = 0; i < Y.size(); i++)
389
  {
390
    for (int j = 0; j < Y[0].size(); j++)
391
    {
392
      vnlY(i,j) = Y[i][j];
393
    }
394
  }
395
// cast hack done
396
  sccanobj->SetGetSmall( false  );
397
  sccanobj->SetCovering( covering );
398
  sccanobj->SetSilent(  ! verbose  );
399
  if( ell1 > 0 )
400
    {
401
    sccanobj->SetUseL1( true );
402
    }
403
  else
404
    {
405
    sccanobj->SetUseL1( false );
406
    }
407
  sccanobj->SetGradStep( std::abs( ell1 ) );
408
  sccanobj->SetMaximumNumberOfIterations( its );
409
  sccanobj->SetRowSparseness( z );
410
  sccanobj->SetSmoother( smooth );
411
  if ( sparsenessx < 0 ) sccanobj->SetKeepPositiveP(false);
412
  if ( sparsenessy < 0 ) sccanobj->SetKeepPositiveQ(false);
413
  sccanobj->SetSCCANFormulation(  SCCANType::PQ );
414
  sccanobj->SetFractionNonZeroP( fabs( sparsenessx ) );
415
  sccanobj->SetFractionNonZeroQ( fabs( sparsenessy ) );
416
  sccanobj->SetMinClusterSizeP( cthreshx );
417
  sccanobj->SetMinClusterSizeQ( cthreshy );
418
  sccanobj->SetMatrixP( vnlX );
419
  sccanobj->SetMatrixQ( vnlY );
420
//  sccanobj->SetMatrixR( r ); // FIXME
421
  sccanobj->SetMaskImageP( maskx );
422
  sccanobj->SetMaskImageQ( masky );
423
  sccanobj->SparsePartialArnoldiCCA( nvecs );
424
425
  // FIXME - should not copy, should map memory
426
  vMatrix solP = sccanobj->GetVariatesP();
427
  std::vector<std::vector<double> > eanatMatp( solP.cols(), std::vector<double>(solP.rows()));
428
  unsigned long rows = solP.rows();
429
  for( unsigned long c = 0; c < solP.cols(); c++ )
430
    {
431
    for( unsigned int r = 0; r < rows; r++ )
432
      {
433
      eanatMatp[c][r] = solP( r, c );
434
      }
435
    }
436
437
  vMatrix solQ = sccanobj->GetVariatesQ();
438
439
  std::vector<std::vector<double> > eanatMatq( solQ.cols(), std::vector<double>(solQ.rows()));
440
  rows = solQ.rows();
441
  for( unsigned long c = 0; c < solQ.cols(); c++ )
442
    {
443
    for( unsigned int r = 0; r < rows; r++ )
444
      {
445
      eanatMatq[c][r] = solQ( r, c );
446
      }
447
    }
448
449
  //nb::ndarray<double> eanatMatpList = nb::cast(eanatMatp);
450
  //nb::ndarray<double> eanatMatqList = nb::cast(eanatMatq);
451
452
   nb::dict res;
453
   res["eig1"] = eanatMatp; 
454
   res["eig2"] = eanatMatq;
455
   return res;
456
457
}
458
459
template <typename ImageType>
460
nb::dict sccanCppV2(std::vector<std::vector<double> > X,
461
                   std::vector<std::vector<double> >  Y,
462
                   AntsImage<ImageType> & maskXimage,
463
                   AntsImage<ImageType> & maskYimage,
464
                   int maskxisnull,
465
                   int maskyisnull,
466
                   float sparsenessx,
467
                   float sparsenessy,
468
                   unsigned int nvecs,
469
                   unsigned int its,
470
                   unsigned int cthreshx,
471
                   unsigned int cthreshy,
472
                   float z,
473
                   float smooth,
474
                   std::vector<AntsImage<ImageType>> initializationListx,
475
                   std::vector<AntsImage<ImageType>> initializationListy,
476
                   float ell1,
477
                   unsigned int verbose,
478
                   float priorWeight,
479
                   unsigned int mycoption,
480
                   unsigned int maxBasedThresh)
481
{
482
  typedef float RealType;
483
  typedef unsigned int IntType;
484
  return  sccanCppHelperV2<ImageType,IntType,RealType>(
485
        X,
486
        Y,
487
        maskXimage,
488
        maskYimage,
489
        maskxisnull,
490
        maskyisnull,
491
        sparsenessx,
492
        sparsenessy,
493
        nvecs,
494
        its,
495
        cthreshx,
496
        cthreshy,
497
        z,
498
        smooth,
499
        initializationListx,
500
        initializationListy,
501
        mycoption,
502
        ell1,
503
        verbose,
504
        priorWeight,
505
        maxBasedThresh
506
        );
507
508
}
509
510
void local_sccaner(nb::module_ &m)
511
{
512
  m.def("sccanCpp2D", &sccanCpp<itk::Image<float, 2>>);
513
  m.def("sccanCpp3D", &sccanCpp<itk::Image<float, 3>>);
514
  m.def("sccanCpp2DV2", &sccanCppV2<itk::Image<float, 2>>);
515
  m.def("sccanCpp3DV2", &sccanCppV2<itk::Image<float, 3>>);
516
}