Diff of /brainchop-mainthread.js [000000] .. [b86468]

Switch to unified view

a b/brainchop-mainthread.js
1
import * as tf from '@tensorflow/tfjs'
2
import { inferenceModelsList } from './brainchop-parameters.js'
3
import {
4
  addZeroPaddingTo3dTensor,
5
  applyMriThreshold,
6
  binarizeVolumeDataTensor,
7
  convByOutputChannelAndInputSlicing,
8
  draw3dObjBoundingVolume,
9
  firstLastNonZero3D,
10
  generateBrainMask,
11
  generateOutputSlicesV2,
12
  getAllSlicesDataAsTF3D,
13
  getModelNumLayers,
14
  getModelNumParameters,
15
  isModelChnlLast,
16
  load_model,
17
  minMaxNormalizeVolumeData,
18
  quantileNormalizeVolumeData,
19
  removeZeroPaddingFrom3dTensor,
20
  resizeWithZeroPadding,
21
  SequentialConvLayer
22
} from './tensor-utils.js'
23
24
async function inferenceFullVolumeSeqCovLayerPhase2(
25
  opts,
26
  modelEntry,
27
  model,
28
  slices_3d,
29
  num_of_slices,
30
  slice_height,
31
  slice_width,
32
  pipeline1_out,
33
  callbackUI,
34
  callbackImg,
35
  statData,
36
  niftiImage
37
) {
38
  // --Phase-2, After remove the skull try to allocate brain volume and make inferece
39
40
  console.log(' ---- Start FullVolume Inference with Sequential Conv Layer for phase-II ---- ')
41
  // console.log("BOB", callbackUI); console.log("UNCLE",callbackImg); return
42
  const quantileNorm = modelEntry.enableQuantileNorm
43
  if (quantileNorm) {
44
    // Quantile normalize function needs specific models to be used
45
    console.log('preModel Quantile normalization enabled')
46
    slices_3d = await quantileNormalizeVolumeData(slices_3d)
47
  } else {
48
    // Min Max Nomalize MRI data to be from 0 to 1
49
    console.log('preModel Min Max normalization enabled')
50
    slices_3d = await minMaxNormalizeVolumeData(slices_3d)
51
  }
52
53
  let mask_3d
54
55
  if (pipeline1_out == null) {
56
    // preModel is null
57
58
    // Check if thresholding the MRI to remove noisy voxels for better cropping is needed.
59
    const autoThresholdValue = modelEntry.autoThreshold
60
61
    if (autoThresholdValue > 0 && autoThresholdValue <= 1) {
62
      // Filtered MRI from noisy voxel below  autoThresholdValue
63
      mask_3d = await applyMriThreshold(slices_3d, autoThresholdValue)
64
    } else {
65
      console.log('No valid crop threshold value')
66
      // binarize original image
67
      mask_3d = await slices_3d.greater([0]).asType('bool')
68
    }
69
  } else {
70
    mask_3d = await pipeline1_out.greater([0]).asType('bool')
71
    // -- pipeline1_out.dispose();
72
  }
73
  console.log(' mask_3d shape :  ', mask_3d.shape)
74
  const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
75
  mask_3d.dispose()
76
  // -- Reference voxel that cropped volume started slice with it
77
  const refVoxel = [row_min, col_min, depth_min]
78
  // -- Starting form refVoxel, size of bounding volume
79
  const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]
80
81
  // -- Extract 3d object (e.g. brain)
82
  const cropped_slices_3d = await slices_3d.slice(
83
    [row_min, col_min, depth_min],
84
    [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]
85
  )
86
  slices_3d.dispose()
87
88
  // -- Padding size add to cropped brain
89
  const pad = modelEntry.cropPadding
90
91
  // Create margin around the bounding volume
92
  let cropped_slices_3d_w_pad = await addZeroPaddingTo3dTensor(cropped_slices_3d, [pad, pad], [pad, pad], [pad, pad])
93
  console.log(' cropped slices_3d with padding shape:  ', cropped_slices_3d_w_pad.shape)
94
95
  cropped_slices_3d.dispose()
96
97
  if (opts.drawBoundingVolume) {
98
    let testVol = await removeZeroPaddingFrom3dTensor(cropped_slices_3d_w_pad, pad, pad, pad)
99
    console.log(' outLabelVolume without padding shape :  ', testVol.shape)
100
101
    testVol = await resizeWithZeroPadding(testVol, num_of_slices, slice_height, slice_width, refVoxel, boundVolSizeArr)
102
    console.log(' outLabelVolume final shape after resizing :  ', testVol.shape)
103
104
    draw3dObjBoundingVolume(tf.unstack(testVol), opts, modelEntry, callbackImg)
105
    testVol.dispose()
106
107
    return 0
108
  }
109
110
  statData.Brainchop_Ver = 'FullVolume'
111
  // model.then(function (res) {
112
  // console.log("--->>>>", opts.drawBoundingVolume); return
113
  const res = await model
114
  try {
115
    let startTime = performance.now()
116
    const inferenceStartTime = performance.now()
117
    // maxLabelPredicted in whole volume of the brain
118
    let maxLabelPredicted = 0
119
    const transpose = modelEntry.enableTranspose
120
    const delay = modelEntry.inferenceDelay
121
    console.log('Inference delay :', delay)
122
123
    if (transpose) {
124
      cropped_slices_3d_w_pad = await cropped_slices_3d_w_pad.transpose()
125
      console.log('Input transposed for pre-model')
126
    } else {
127
      console.log('Transpose not enabled for pre-model')
128
    }
129
130
    let i = 1
131
    const layersLength = res.layers.length
132
    console.log('res.layers.length ', layersLength)
133
134
    const isChannelLast = isModelChnlLast(res)
135
    const batchSize = opts.batchSize
136
    const numOfChan = opts.numOfChan
137
    let adjusted_input_shape
138
    // -- Adjust model input shape
139
    if (isChannelLast) {
140
      res.layers[0].batchInputShape[1] = cropped_slices_3d_w_pad.shape[0]
141
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[1]
142
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[2]
143
144
      adjusted_input_shape = [
145
        batchSize,
146
        res.layers[0].batchInputShape[1],
147
        res.layers[0].batchInputShape[2],
148
        res.layers[0].batchInputShape[3],
149
        numOfChan
150
      ]
151
    } else {
152
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[0]
153
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[1]
154
      res.layers[0].batchInputShape[4] = cropped_slices_3d_w_pad.shape[2]
155
156
      adjusted_input_shape = [
157
        batchSize,
158
        numOfChan,
159
        res.layers[0].batchInputShape[2],
160
        res.layers[0].batchInputShape[3],
161
        res.layers[0].batchInputShape[4]
162
      ]
163
    }
164
165
    console.log(' Model batch input shape : ', res.layers[0].batchInputShape)
166
    // -- batchInputShape {Array} input_shape - e.g. [?, D, H, W, Ch] or [?, Ch, D, H, W]
167
168
    statData.Input_Shape = JSON.stringify(res.layers[0].batchInputShape)
169
    statData.Output_Shape = JSON.stringify(res.output.shape)
170
    statData.Channel_Last = isChannelLast
171
    statData.Model_Param = await getModelNumParameters(res)
172
    statData.Model_Layers = await getModelNumLayers(res)
173
    statData.Model = modelEntry.modelName
174
    statData.Seq_Conv = modelEntry.enableSeqConv
175
    statData.Extra_Info = null
176
177
    // Determine the number of output channels in the last layer of the model
178
    //  e.g. 3, 50, 104
179
    const outputLayer = res.layers[res.layers.length - 1]
180
    console.log('Output Layer : ', outputLayer)
181
182
    const expected_Num_labels = isChannelLast
183
      ? outputLayer.outputShape[outputLayer.outputShape.length - 1]
184
      : outputLayer.outputShape[1]
185
    console.log('Num of output channels : ', expected_Num_labels)
186
187
    const curTensor = []
188
    curTensor[0] = await cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
189
    const timer = window.setInterval(async function () {
190
      try {
191
        if (res.layers[i].activation.getClassName() !== 'linear') {
192
          curTensor[i] = await res.layers[i].apply(curTensor[i - 1])
193
        } else {
194
          curTensor[i] = await convByOutputChannelAndInputSlicing(
195
            curTensor[i - 1],
196
            res.layers[i].getWeights()[0],
197
            res.layers[i].getWeights()[1],
198
            res.layers[i].strides,
199
            res.layers[i].padding,
200
            res.layers[i].dilationRate,
201
            3
202
          ) // important for memory use
203
        }
204
        tf.dispose(curTensor[i - 1])
205
      } catch (err) {
206
        const errTxt = 'Your graphics card (e.g. Intel) may not be compatible with WebGL. ' + err.message
207
        callbackUI(errTxt, -1, errTxt)
208
209
        window.clearInterval(timer)
210
        tf.engine().endScope()
211
        tf.engine().disposeVariables()
212
213
        statData.Inference_t = Infinity
214
        statData.Postprocess_t = Infinity
215
        statData.Status = 'Fail'
216
        statData.Error_Type = err.message
217
        statData.Extra_Err_Info = 'Failed while model layer ' + i + ' apply'
218
219
        callbackUI('', -1, '', statData)
220
221
        return 0
222
      }
223
224
      console.log('layer output Tensor shape : ', curTensor[i].shape)
225
      console.log('layer count params ', res.layers[i].countParams())
226
227
      res.layers[i].dispose()
228
      curTensor[i - 1].dispose()
229
230
      callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
231
      if (tf.memory().unreliable) {
232
        const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
233
        callbackUI(unreliableReasons, NaN, unreliableReasons)
234
      }
235
      if (i === layersLength - 2) {
236
        // Stop before the last layer or classification layer.
237
238
        window.clearInterval(timer)
239
        // // Create an instance of SequentialConvLayer
240
        // The second parameter is important for memory,
241
        // the larger it is, the more memory it uses
242
        // it was 8, but I set it to 3, got a different error
243
        // let seqConvLayer = new SequentialConvLayer(res, 10, isChannelLast);
244
        const seqConvLayer = await new SequentialConvLayer(res, 10, isChannelLast, callbackUI, false)
245
        // Apply the last output tensor to the seq. instance
246
        const outputTensor = await seqConvLayer.apply(curTensor[i])
247
        callbackUI('seqConvLayer Done')
248
        // -- document.getElementById("progressBarChild").style.width = 0 + "%";;
249
250
        // Dispose the previous layer input tensor
251
        tf.dispose(curTensor[i])
252
        // delete the used class
253
        // ? delete seqConvLayer;
254
255
        // You can now use 'outputTensor' as needed
256
        console.log(' Output tensor', outputTensor)
257
        console.log(' Output tensor shape : ', outputTensor.shape)
258
        // Array(3) [ 256, 256, 256 ]
259
260
        if (outputTensor.shape.length !== 3) {
261
          const msg = 'Output tensor shape should be 3 dims but it is ' + outputTensor.shape.length
262
          callbackUI(msg, -1, msg)
263
        }
264
265
        console.log(' find array max ')
266
        const curBatchMaxLabel = await outputTensor.max().dataSync()[0]
267
        const Inference_t = ((performance.now() - startTime) / 1000).toFixed(4)
268
269
        if (maxLabelPredicted < curBatchMaxLabel) {
270
          maxLabelPredicted = curBatchMaxLabel
271
        }
272
273
        const numSegClasses = maxLabelPredicted + 1
274
        console.log('Predicted num of segmentation classes', numSegClasses)
275
        statData.Actual_Labels = numSegClasses
276
        statData.Expect_Labels = expected_Num_labels
277
        statData.NumLabels_Match = numSegClasses === expected_Num_labels
278
        if (numSegClasses !== expected_Num_labels) {
279
          const msg = 'expected ' + expected_Num_labels + ' labels, but the predicted are ' + numSegClasses
280
          callbackUI(msg, -1, msg)
281
        }
282
283
        // -- Transpose back to original unpadded size
284
        let outLabelVolume = outputTensor.reshape([
285
          cropped_slices_3d_w_pad.shape[0],
286
          cropped_slices_3d_w_pad.shape[1],
287
          cropped_slices_3d_w_pad.shape[2]
288
        ])
289
        tf.dispose(outputTensor)
290
291
        // Transpose MRI data to be match pytorch/keras input output
292
        if (transpose) {
293
          console.log('outLabelVolume transposed')
294
          outLabelVolume = outLabelVolume.transpose()
295
        }
296
297
        outLabelVolume = await removeZeroPaddingFrom3dTensor(outLabelVolume, pad, pad, pad)
298
        console.log(' outLabelVolume without padding shape :  ', outLabelVolume.shape)
299
        outLabelVolume = await resizeWithZeroPadding(
300
          outLabelVolume,
301
          num_of_slices,
302
          slice_height,
303
          slice_width,
304
          refVoxel,
305
          boundVolSizeArr
306
        )
307
        console.log(' outLabelVolume final shape after resizing :  ', outLabelVolume.shape)
308
309
        // let filterOutWithPreMask =  inferenceModelsList[$$("selectModel").getValue() - 1]["filterOutWithPreMask"];
310
        const filterOutWithPreMask = modelEntry.filterOutWithPreMask
311
        // To clean the skull area wrongly segmented inphase-2.
312
        if (pipeline1_out != null && opts.isBrainCropMaskBased && filterOutWithPreMask) {
313
          const bin = await binarizeVolumeDataTensor(pipeline1_out)
314
          outLabelVolume = await outLabelVolume.mul(bin)
315
        }
316
317
        startTime = performance.now()
318
        // Generate output volume or slices
319
        console.log('Generating correct output')
320
        let outimg
321
        try {
322
          const img = await new Uint32Array(outLabelVolume.dataSync())
323
          const Vshape = outLabelVolume.shape
324
          const Vtype = outLabelVolume.dtype
325
          outimg = await generateOutputSlicesV2(
326
            img,
327
            Vshape,
328
            Vtype,
329
            num_of_slices,
330
            numSegClasses,
331
            slice_height,
332
            slice_width,
333
            modelEntry,
334
            opts,
335
            niftiImage
336
          )
337
          console.log(' Phase-2 num of tensors after generateOutputSlicesV2: ', tf.memory().numTensors)
338
339
          tf.dispose(outLabelVolume)
340
          tf.engine().endScope()
341
          tf.engine().disposeVariables()
342
        } catch (error) {
343
          // -- Timing data to collect
344
          tf.engine().endScope()
345
          tf.engine().disposeVariables()
346
          console.log('Error while generating output: ', error)
347
          const msg = 'Failed while generating output due to limited browser memory available'
348
          callbackUI(msg, -1, msg)
349
350
          statData.Inference_t = Inference_t
351
          statData.Postprocess_t = Infinity
352
          statData.Status = 'Fail'
353
          statData.Error_Type = error.message
354
          statData.Extra_Err_Info = 'Failed while generating output'
355
356
          callbackUI('', -1, '', statData)
357
358
          return 0
359
        }
360
        const Postprocess_t = ((performance.now() - startTime) / 1000).toFixed(4)
361
362
        console.log(
363
          'Processing the whole brain volume in tfjs for multi-class output mask took : ',
364
          ((performance.now() - inferenceStartTime) / 1000).toFixed(4) + '  Seconds'
365
        )
366
367
        // -- Timing data to collect
368
        statData.Inference_t = Inference_t
369
        statData.Postprocess_t = Postprocess_t
370
        statData.Status = 'OK'
371
372
        callbackUI('', -1, '', statData)
373
        callbackUI('Segmentation finished', 0)
374
        callbackImg(outimg, opts, modelEntry)
375
        return 0
376
      } else {
377
        i++
378
      }
379
    }, delay)
380
  } catch (err) {
381
    callbackUI(err.message, -1, err.message)
382
    console.log(
383
      'If webgl context is lost, try to restore webgl context by visit the link ' +
384
        '<a href="https://support.biodigital.com/hc/en-us/articles/218322977-How-to-turn-on-WebGL-in-my-browser">here</a>'
385
    )
386
    if (tf.memory().unreliable) {
387
      const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
388
      callbackUI(unreliableReasons, NaN, unreliableReasons)
389
    }
390
  }
391
}
392
393
async function inferenceFullVolumePhase2(
394
  model,
395
  slices_3d,
396
  num_of_slices,
397
  slice_height,
398
  slice_width,
399
  pipeline1_out,
400
  modelEntry,
401
  statData,
402
  opts,
403
  callbackImg,
404
  callbackUI,
405
  niftiImage
406
) {
407
  let outimg = []
408
  // --Phase-2, After remove the skull try to allocate brain volume and make inferece
409
  console.log(' ---- Start FullVolume inference phase-II ---- ')
410
  const quantileNorm = modelEntry.enableQuantileNorm
411
  if (quantileNorm) {
412
    // Quantile normalize function needs specific models to be used
413
    console.log('preModel Quantile normalization enabled')
414
    slices_3d = await quantileNormalizeVolumeData(slices_3d)
415
  } else {
416
    // Min Max Nomalize MRI data to be from 0 to 1
417
    console.log('preModel Min Max normalization enabled')
418
    slices_3d = await minMaxNormalizeVolumeData(slices_3d)
419
  }
420
  let mask_3d
421
  if (pipeline1_out == null) {
422
    // preModel is null
423
424
    // Check if thresholding the MRI to remove noisy voxels for better cropping is needed.
425
    const autoThresholdValue = modelEntry.autoThreshold
426
427
    if (autoThresholdValue > 0 && autoThresholdValue <= 1) {
428
      // Filtered MRI from noisy voxel below  autoThresholdValue
429
      mask_3d = await applyMriThreshold(slices_3d, autoThresholdValue)
430
    } else {
431
      console.log('No valid crop threshold value')
432
      // binarize original image
433
      mask_3d = await slices_3d.greater([0]).asType('bool')
434
    }
435
  } else {
436
    mask_3d = pipeline1_out.greater([0]).asType('bool')
437
    // -- pipeline1_out.dispose()
438
  }
439
  console.log(' mask_3d shape :  ', mask_3d.shape)
440
  const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
441
  mask_3d.dispose()
442
  // -- Reference voxel that cropped volume started slice with it
443
  const refVoxel = [row_min, col_min, depth_min]
444
  console.log('refVoxel :', refVoxel)
445
446
  // -- Starting form refVoxel, size of bounding volume
447
  const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]
448
449
  console.log('boundVolSizeArr :', boundVolSizeArr)
450
451
  // -- Extract 3d object (e.g. brain)
452
  const cropped_slices_3d = slices_3d.slice(
453
    [row_min, col_min, depth_min],
454
    [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]
455
  )
456
457
  slices_3d.dispose()
458
459
  // -- Padding size add to cropped brain
460
  const pad = modelEntry.cropPadding
461
462
  // Create margin around the bounding volume
463
  let cropped_slices_3d_w_pad = await addZeroPaddingTo3dTensor(cropped_slices_3d, [pad, pad], [pad, pad], [pad, pad])
464
  console.log(' cropped slices_3d with padding shape:  ', cropped_slices_3d_w_pad.shape)
465
466
  cropped_slices_3d.dispose()
467
468
  // -- Test dim after padding ..
469
  // for (let i = 0; i < cropped_slices_3d_w_pad.rank; i++) {
470
  //     if(cropped_slices_3d_w_pad.shape[i] > 256) {
471
  //          console.log(" cropped_slices_3d_w_pad > 256 ")
472
  //     }
473
474
  // }
475
476
  if (opts.drawBoundingVolume) {
477
    let testVol = await removeZeroPaddingFrom3dTensor(cropped_slices_3d_w_pad, pad, pad, pad)
478
    console.log(' outLabelVolume without padding shape :  ', testVol.shape)
479
480
    testVol = await resizeWithZeroPadding(testVol, num_of_slices, slice_height, slice_width, refVoxel, boundVolSizeArr)
481
    console.log(' outLabelVolume final shape after resizing :  ', testVol.shape)
482
    draw3dObjBoundingVolume(tf.unstack(testVol), opts, modelEntry, callbackImg)
483
484
    testVol.dispose()
485
486
    return 0
487
  }
488
489
  statData.Brainchop_Ver = 'FullVolume'
490
  let startTime = performance.now()
491
  let adjusted_input_shape = []
492
  const res = await model
493
  try {
494
    startTime = performance.now()
495
    const inferenceStartTime = performance.now()
496
    // maxLabelPredicted in whole volume of the brain
497
    let maxLabelPredicted = 0
498
    const transpose = modelEntry.enableTranspose
499
    const delay = modelEntry.inferenceDelay
500
    console.log('Inference delay :', delay)
501
502
    if (transpose) {
503
      cropped_slices_3d_w_pad = cropped_slices_3d_w_pad.transpose()
504
      console.log('Input transposed for pre-model')
505
    } else {
506
      console.log('Transpose not enabled for pre-model')
507
    }
508
509
    let i = 1
510
    const layersLength = res.layers.length
511
    console.log('res.layers.length ', layersLength)
512
513
    const isChannelLast = await isModelChnlLast(res)
514
    const batchSize = opts.batchSize
515
    const numOfChan = opts.numOfChan
516
517
    // -- Adjust model input shape
518
    if (isChannelLast) {
519
      res.layers[0].batchInputShape[1] = cropped_slices_3d_w_pad.shape[0]
520
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[1]
521
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[2]
522
523
      adjusted_input_shape = [
524
        batchSize,
525
        res.layers[0].batchInputShape[1],
526
        res.layers[0].batchInputShape[2],
527
        res.layers[0].batchInputShape[3],
528
        numOfChan
529
      ]
530
    } else {
531
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[0]
532
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[1]
533
      res.layers[0].batchInputShape[4] = cropped_slices_3d_w_pad.shape[2]
534
535
      adjusted_input_shape = [
536
        batchSize,
537
        numOfChan,
538
        res.layers[0].batchInputShape[2],
539
        res.layers[0].batchInputShape[3],
540
        res.layers[0].batchInputShape[4]
541
      ]
542
    }
543
544
    console.log(' Model batch input shape : ', res.layers[0].batchInputShape)
545
    // -- batchInputShape {Array} input_shape - e.g. [?, D, H, W, Ch] or [?, Ch, D, H, W]
546
547
    statData.Input_Shape = JSON.stringify(res.layers[0].batchInputShape)
548
    statData.Output_Shape = JSON.stringify(res.output.shape)
549
    statData.Channel_Last = isChannelLast
550
    statData.Model_Param = await getModelNumParameters(res)
551
    statData.Model_Layers = await getModelNumLayers(res)
552
    statData.Model = modelEntry.modelName
553
    statData.Extra_Info = null
554
555
    const curTensor = []
556
    curTensor[0] = cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
557
    const timer = window.setInterval(async function () {
558
      try {
559
        curTensor[i] = res.layers[i].apply(curTensor[i - 1])
560
      } catch (err) {
561
        callbackUI(err.message, -1, err.message)
562
        window.clearInterval(timer)
563
        tf.engine().endScope()
564
        tf.engine().disposeVariables()
565
566
        statData.Inference_t = Infinity
567
        statData.Postprocess_t = Infinity
568
        statData.Status = 'Fail'
569
        statData.Error_Type = err.message
570
        statData.Extra_Err_Info = 'Failed while model layer ' + i + ' apply'
571
572
        callbackUI('', -1, '', statData)
573
574
        return 0
575
      }
576
      callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
577
      console.log('layer output Tensor shape : ', curTensor[i].shape)
578
      console.log('layer count params ', res.layers[i].countParams())
579
      res.layers[i].dispose()
580
      curTensor[i - 1].dispose()
581
      if (tf.memory().unreliable) {
582
        const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
583
        callbackUI(unreliableReasons, NaN, unreliableReasons)
584
      }
585
      if (i === layersLength - 1) {
586
        window.clearInterval(timer)
587
588
        const axis = isChannelLast ? -1 : 1
589
        console.log(' find argmax ')
590
        console.log('last Tensor shape : ', curTensor[i].shape)
591
        // -- curTensor[i].shape  e.g. [ 1, 256, 256, 256, 3 ]
592
        const expected_Num_labels = isChannelLast ? curTensor[i].shape[4] : curTensor[i].shape[1]
593
        let prediction_argmax
594
595
        // Try for argMax with model output tensor.
596
597
        try {
598
          const argMaxTime = performance.now()
599
          console.log(' Try tf.argMax for fullVolume ..')
600
          prediction_argmax = tf.argMax(curTensor[i], axis)
601
          console.log('tf.argMax for fullVolume takes : ', ((performance.now() - argMaxTime) / 1000).toFixed(4))
602
        } catch (err1) {
603
          // if channel last
604
          if (axis === -1) {
605
            try {
606
              const argMaxLargeTime = performance.now()
607
              console.log(' tf.argMax failed .. try argMaxLarge ..')
608
              window.alert('tensor2LightBuffer() is not dead code?')
609
              window.alert('argMaxLarge() is not dead code?')
610
              console.log(
611
                'argMaxLarge for fullVolume takes : ',
612
                ((performance.now() - argMaxLargeTime) / 1000).toFixed(4)
613
              )
614
            } catch (err2) {
615
              const errTxt = "argMax buffer couldn't be created due to limited memory resources."
616
              callbackUI(errTxt, -1, errTxt)
617
618
              window.clearInterval(timer)
619
              tf.engine().endScope()
620
              tf.engine().disposeVariables()
621
622
              statData.Inference_t = Infinity
623
              statData.Postprocess_t = Infinity
624
              statData.Status = 'Fail'
625
              statData.Error_Type = err2.message
626
              statData.Extra_Err_Info = 'prediction_argmax from argMaxLarge failed'
627
628
              callbackUI('', -1, '', statData)
629
630
              return 0
631
            }
632
          } else {
633
            // if channel first ..
634
            const errTxt = "argMax buffer couldn't be created due to limited memory resources."
635
            callbackUI(errTxt, -1, errTxt)
636
637
            prediction_argmax.dispose()
638
639
            window.clearInterval(timer)
640
            tf.engine().endScope()
641
            tf.engine().disposeVariables()
642
643
            statData.Inference_t = Infinity
644
            statData.Postprocess_t = Infinity
645
            statData.Status = 'Fail'
646
            statData.Error_Type = err1.message
647
            statData.Extra_Err_Info = 'prediction_argmax from argMaxLarge not support yet channel first'
648
649
            callbackUI('', -1, '', statData)
650
651
            return 0
652
          }
653
        }
654
655
        console.log(' prediction_argmax shape : ', prediction_argmax.shape)
656
        // -- prediction_argmax.shape  : [ 1, 256, 256, 256]
657
658
        const Inference_t = ((performance.now() - startTime) / 1000).toFixed(4)
659
660
        // outputDataBeforArgmx = Array.from(prediction_argmax.dataSync())
661
        tf.dispose(curTensor[i])
662
        // allPredictions.push({"id": allBatches[j].id, "coordinates": allBatches[j].coordinates, "data": Array.from(prediction_argmax.dataSync()) })
663
        const curBatchMaxLabel = await prediction_argmax.max().dataSync()[0]
664
        if (maxLabelPredicted < curBatchMaxLabel) {
665
          maxLabelPredicted = curBatchMaxLabel
666
        }
667
668
        const numSegClasses = maxLabelPredicted + 1
669
        console.log('numSegClasses', numSegClasses)
670
        statData.Actual_Labels = numSegClasses
671
        statData.Expect_Labels = expected_Num_labels
672
        statData.NumLabels_Match = numSegClasses === expected_Num_labels
673
674
        if (numSegClasses !== expected_Num_labels) {
675
          // errTxt = "expected " + expected_Num_labels + " labels, but the predicted are " + numSegClasses + ". For possible solutions please refer to <a href='https://github.com/neuroneural/brainchop/wiki/FAQ#Q3' target='_blank'><b> FAQ </b></a>.", "alert-error"
676
          const errTxt = 'expected ' + expected_Num_labels + ' labels, but the predicted are ' + numSegClasses
677
          callbackUI(errTxt, -1, errTxt)
678
        }
679
680
        // -- Transpose back to original unpadded size
681
        let outLabelVolume = prediction_argmax.reshape([
682
          cropped_slices_3d_w_pad.shape[0],
683
          cropped_slices_3d_w_pad.shape[1],
684
          cropped_slices_3d_w_pad.shape[2]
685
        ])
686
        tf.dispose(prediction_argmax)
687
688
        // Transpose MRI data to be match pytorch/keras input output
689
        if (transpose) {
690
          console.log('outLabelVolume transposed')
691
          outLabelVolume = outLabelVolume.transpose()
692
        }
693
        outLabelVolume = await removeZeroPaddingFrom3dTensor(outLabelVolume, pad, pad, pad)
694
        console.log(' outLabelVolume without padding shape :  ', outLabelVolume.shape)
695
        outLabelVolume = await resizeWithZeroPadding(
696
          outLabelVolume,
697
          num_of_slices,
698
          slice_height,
699
          slice_width,
700
          refVoxel,
701
          boundVolSizeArr
702
        )
703
        console.log(' outLabelVolume final shape after resizing :  ', outLabelVolume.shape)
704
705
        const filterOutWithPreMask = modelEntry.filterOutWithPreMask
706
        // To clean the skull area wrongly segmented in phase-2.
707
        if (pipeline1_out != null && opts.isBrainCropMaskBased && filterOutWithPreMask) {
708
          const bin = binarizeVolumeDataTensor(pipeline1_out)
709
          outLabelVolume = outLabelVolume.mul(bin)
710
        }
711
712
        startTime = performance.now()
713
        // Generate output volume or slices
714
        console.log('Generating correct output')
715
716
        try {
717
          const img = new Uint32Array(outLabelVolume.dataSync())
718
          const Vshape = outLabelVolume.shape
719
          const Vtype = outLabelVolume.dtype
720
          tf.dispose(outLabelVolume)
721
          tf.engine().endScope()
722
          tf.engine().disposeVariables()
723
          outimg = await generateOutputSlicesV2(
724
            img,
725
            Vshape,
726
            Vtype,
727
            num_of_slices,
728
            numSegClasses,
729
            slice_height,
730
            slice_width,
731
            modelEntry,
732
            opts,
733
            niftiImage
734
          )
735
          console.log(' Phase-2 num of tensors after generateOutputSlicesV2: ', tf.memory().numTensors)
736
        } catch (error) {
737
          // -- Timing data to collect
738
          tf.engine().endScope()
739
          tf.engine().disposeVariables()
740
741
          const errTxt = 'Failed while generating output due to limited browser memory available'
742
          callbackUI(errTxt, -1, errTxt)
743
          statData.Inference_t = Inference_t
744
          statData.Postprocess_t = Infinity
745
          statData.Status = 'Fail'
746
          statData.Error_Type = error.message
747
          statData.Extra_Err_Info = 'Failed while generating output'
748
749
          callbackUI('', -1, '', statData)
750
751
          return 0
752
        }
753
754
        const Postprocess_t = ((performance.now() - startTime) / 1000).toFixed(4)
755
756
        // tf.engine().endScope()
757
        tf.engine().disposeVariables()
758
759
        console.log(
760
          'Processing the whole brain volume in tfjs for multi-class output mask took : ',
761
          ((performance.now() - inferenceStartTime) / 1000).toFixed(4) + '  Seconds'
762
        )
763
764
        // -- Timing data to collect
765
        statData.Inference_t = Inference_t
766
        statData.Postprocess_t = Postprocess_t
767
        statData.Status = 'OK'
768
769
        callbackUI('', -1, '', statData)
770
        clearInterval(timer)
771
        callbackUI('Segmentation finished', 0)
772
        callbackImg(outimg, opts, modelEntry)
773
        return 0
774
      }
775
      i++
776
    }, delay)
777
  } catch (err) {
778
    callbackUI(err.message, -1, err.message)
779
    console.log(
780
      'If webgl context is lost, try to restore webgl context by visit the link ' +
781
        '<a href="https://support.biodigital.com/hc/en-us/articles/218322977-How-to-turn-on-WebGL-in-my-browser">here</a>'
782
    )
783
  }
784
  //            })
785
}
786
787
async function inferenceFullVolumePhase1(
788
  model,
789
  slices_3d,
790
  num_of_slices,
791
  slice_height,
792
  slice_width,
793
  isModelFullVol,
794
  modelEntry,
795
  statData,
796
  opts,
797
  callbackImg,
798
  callbackUI,
799
  niftiImage
800
) {
801
  statData.No_SubVolumes = 1
802
  // load pre-model for inference first, can be null if no pre-model such as GWM models
803
  if (modelEntry.preModelId) {
804
    const preModel = await load_model(opts.rootURL + inferenceModelsList[modelEntry.preModelId - 1].path)
805
    const transpose = inferenceModelsList[modelEntry.preModelId - 1].enableTranspose
806
    const quantileNorm = inferenceModelsList[modelEntry.preModelId - 1].enableQuantileNorm
807
    let preModel_slices_3d = null
808
809
    // -- If pre-model is not null then slices_3d mask will be generated..
810
    // -- The mask is needed to remove the skull and set noise in background to 0, and get the brain bounding volume properly
811
    const slices_3d_mask = null
812
813
    if (quantileNorm) {
814
      // Quantile normalize function needs specific models to be used
815
      console.log('preModel Quantile normalization enabled')
816
      preModel_slices_3d = await quantileNormalizeVolumeData(slices_3d)
817
    } else {
818
      // Min Max Nomalize MRI data to be from 0 to 1
819
      console.log('preModel Min Max normalization enabled')
820
      preModel_slices_3d = await minMaxNormalizeVolumeData(slices_3d)
821
    }
822
823
    // -- Transpose MRI data to be match pytorch/keras input output
824
    // -- Check if pre-model needs transpose..
825
    if (transpose) {
826
      preModel_slices_3d = await preModel_slices_3d.transpose()
827
      console.log('Input transposed for pre-model')
828
    } else {
829
      console.log('Transpose not enabled for pre-model')
830
    }
831
832
    statData.Brainchop_Ver = 'PreModel_FV' // e.g. "PreModel_FV"
833
834
    // preModel.then(function (res) {
835
    const res = await preModel
836
837
    try {
838
      const inferenceStartTime = performance.now()
839
      const preModelObject = res
840
841
      // read input shape from model.json object
842
      const preModelBatchInputShape = preModelObject.layers[0].batchInputShape
843
      console.log(' Pre-Model batch input shape : ', preModelBatchInputShape)
844
845
      // -- Verify input shape
846
      if (preModelBatchInputShape.length !== 5) {
847
        const errTxt = 'The pre-model input shape must be 5D '
848
        callbackUI(errTxt, -1, errTxt)
849
        return 0
850
      }
851
852
      const isPreModelChannelLast = isModelChnlLast(preModelObject)
853
      const batchSize = opts.batchSize
854
      const numOfChan = opts.numOfChan
855
      let batch_D, batch_H, batch_W
856
      let preModel_input_shape
857
      if (isPreModelChannelLast) {
858
        console.log('Pre-Model Channel Last')
859
        if (isNaN(preModelBatchInputShape[4]) || preModelBatchInputShape[4] !== 1) {
860
          const errTxt = 'The number of channels for pre-model input shape must be 1'
861
          callbackUI(errTxt, -1, errTxt)
862
          return 0
863
        }
864
865
        batch_D = preModelBatchInputShape[1]
866
        batch_H = preModelBatchInputShape[2]
867
        batch_W = preModelBatchInputShape[3]
868
869
        preModel_input_shape = [batchSize, batch_D, batch_H, batch_W, numOfChan]
870
      } else {
871
        console.log('Pre-Model Channel First')
872
        if (isNaN(preModelBatchInputShape[1]) || preModelBatchInputShape[1] !== 1) {
873
          const errTxt = 'The number of channels for pre-model input shape must be 1'
874
          callbackUI(errTxt, -1, errTxt)
875
          return 0
876
        }
877
878
        batch_D = preModelBatchInputShape[2]
879
        batch_H = preModelBatchInputShape[3]
880
        batch_W = preModelBatchInputShape[4]
881
882
        preModel_input_shape = [batchSize, numOfChan, batch_D, batch_H, batch_W]
883
      }
884
885
      statData.Input_Shape = JSON.stringify(preModel_input_shape)
886
      statData.Output_Shape = JSON.stringify(preModelObject.output.shape)
887
      statData.Channel_Last = isPreModelChannelLast
888
      statData.Model_Param = await getModelNumParameters(preModelObject)
889
      statData.Model_Layers = await getModelNumLayers(preModelObject)
890
891
      // maxLabelPredicted in whole volume of the brain
892
      let maxLabelPredicted = 0
893
      const delay = inferenceModelsList[modelEntry.preModelId - 1].inferenceDelay
894
895
      let i = 1
896
      const layersLength = res.layers.length
897
898
      const curTensor = []
899
      // -- reshape MRI to model input shape
900
      curTensor[0] = preModel_slices_3d.reshape(preModel_input_shape)
901
902
      // Dispose the volume
903
      tf.dispose(preModel_slices_3d)
904
905
      const timer = window.setInterval(async function () {
906
        try {
907
          curTensor[i] = await res.layers[i].apply(curTensor[i - 1])
908
        } catch (err) {
909
          const errTxt = 'Your graphics card (e.g. Intel) may not be compatible with WebGL. ' + err.message
910
          callbackUI(errTxt, -1, errTxt)
911
912
          window.clearInterval(timer)
913
          tf.engine().endScope()
914
          tf.engine().disposeVariables()
915
916
          statData.Inference_t = Infinity
917
          statData.Postprocess_t = Infinity
918
          statData.Status = 'Fail'
919
          statData.Error_Type = err.message
920
          statData.Extra_Err_Info = 'PreModel Failed while model layer ' + i + ' apply'
921
922
          callbackUI('', -1, '', statData)
923
924
          return 0
925
        }
926
927
        res.layers[i].dispose()
928
        curTensor[i - 1].dispose()
929
930
        callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
931
        if (tf.memory().unreliable) {
932
          const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
933
          callbackUI(unreliableReasons, NaN, unreliableReasons)
934
        }
935
936
        if (i === layersLength - 1) {
937
          window.clearInterval(timer)
938
939
          // -- prediction = res.layers[res.layers.length-1].apply(curTensor[i])
940
          // -- curTensor[i].print()
941
          // -- outputDataBeforArgmx = Array.from(curTensor[i].dataSync())
942
943
          const axis = isPreModelChannelLast ? -1 : 1
944
          console.log(' find argmax ')
945
          console.log('last Tensor shape : ', curTensor[i].shape)
946
          // -- curTensor[i].shape  : [ 1, 256, 256, 256, 3 ]
947
          const expected_Num_labels = isPreModelChannelLast ? curTensor[i].shape[4] : curTensor[i].shape[1]
948
          let prediction_argmax
949
950
          // Try for argMax with model output tensor.
951
952
          try {
953
            console.log(' Try tf.argMax for fullVolume ..')
954
            prediction_argmax = await tf.argMax(curTensor[i], axis)
955
          } catch (err1) {
956
            // if channel last
957
            if (axis === -1) {
958
              try {
959
                const argMaxLargeTime = performance.now()
960
                console.log(' tf.argMax failed .. try argMaxLarge ..')
961
                window.alert('tensor2LightBuffer() is not dead code?')
962
                window.alert('argMaxLarge() is not dead code?')
963
                console.log(
964
                  'argMaxLarge for fullVolume takes : ',
965
                  ((performance.now() - argMaxLargeTime) / 1000).toFixed(4)
966
                )
967
              } catch (err2) {
968
                const errTxt = "argMax buffer couldn't be created due to limited memory resources."
969
                callbackUI(errTxt, -1, errTxt)
970
971
                prediction_argmax.dispose()
972
973
                window.clearInterval(timer)
974
                tf.engine().endScope()
975
                tf.engine().disposeVariables()
976
977
                statData.Inference_t = Infinity
978
                statData.Postprocess_t = Infinity
979
                statData.Status = 'Fail'
980
                statData.Error_Type = err2.message
981
                statData.Extra_Err_Info = 'preModel prediction_argmax from argMaxLarge failed'
982
983
                callbackUI('', -1, '', statData)
984
985
                return 0
986
              }
987
            } else {
988
              // if channel first ..
989
              const errTxt = "argMax buffer couldn't be created due to limited memory resources."
990
              callbackUI(errTxt, -1, errTxt)
991
992
              prediction_argmax.dispose()
993
994
              window.clearInterval(timer)
995
              tf.engine().endScope()
996
              tf.engine().disposeVariables()
997
998
              statData.Inference_t = Infinity
999
              statData.Postprocess_t = Infinity
1000
              statData.Status = 'Fail'
1001
              statData.Error_Type = err1.message
1002
              statData.Extra_Err_Info = 'preModel prediction_argmax from argMaxLarge not support yet channel first'
1003
1004
              callbackUI('', -1, '', statData)
1005
1006
              return 0
1007
            }
1008
          }
1009
          console.log(' Pre-model prediction_argmax shape : ', prediction_argmax.shape)
1010
          // -- prediction_argmax.shape  : [ 1, 256, 256, 256]
1011
1012
          const Inference_t = ((performance.now() - inferenceStartTime) / 1000).toFixed(4)
1013
1014
          tf.dispose(curTensor[i])
1015
1016
          console.log(' Pre-model find array max ')
1017
          const curBatchMaxLabel = await prediction_argmax.max().dataSync()[0]
1018
          if (maxLabelPredicted < curBatchMaxLabel) {
1019
            maxLabelPredicted = curBatchMaxLabel
1020
          }
1021
1022
          const numSegClasses = maxLabelPredicted + 1
1023
          console.log('Pre-model numSegClasses', numSegClasses)
1024
1025
          statData.Actual_Labels = numSegClasses
1026
          statData.Expect_Labels = expected_Num_labels
1027
          statData.NumLabels_Match = numSegClasses === expected_Num_labels
1028
1029
          // -- Transpose back to original unpadded size
1030
          let outLabelVolume = await prediction_argmax.reshape([num_of_slices, slice_height, slice_width])
1031
          tf.dispose(prediction_argmax)
1032
          // Transpose MRI data to be match pytorch/keras input output
1033
          if (transpose) {
1034
            console.log('Pre-model outLabelVolume transposed')
1035
            outLabelVolume = outLabelVolume.transpose()
1036
          }
1037
          const startTime = performance.now()
1038
          // Generate output volume or slices
1039
          console.log('Generating pre-model output')
1040
          let slices_3d_mask
1041
          try {
1042
            const unstackOutVolumeTensor = await tf.unstack(outLabelVolume)
1043
            slices_3d_mask = await generateBrainMask(
1044
              unstackOutVolumeTensor,
1045
              num_of_slices,
1046
              slice_height,
1047
              slice_width,
1048
              modelEntry,
1049
              opts,
1050
              callbackUI,
1051
              callbackImg,
1052
              false
1053
            )
1054
            await tf.dispose(outLabelVolume)
1055
            console.log(' Phase-1 num of tensors after generateBrainMask: ', tf.memory().numTensors)
1056
          } catch (error) {
1057
            // -- Timing data to collect
1058
            tf.engine().endScope()
1059
            tf.engine().disposeVariables()
1060
1061
            const errTxt = 'Failed while generating pre-model output due to limited browser memory available'
1062
            callbackUI(errTxt, -1, errTxt)
1063
1064
            statData.Inference_t = Inference_t
1065
            statData.Postprocess_t = Infinity
1066
            statData.Status = 'Fail'
1067
            statData.Error_Type = error.message
1068
            statData.Extra_Err_Info = 'Pre-model failed while generating output'
1069
1070
            callbackUI('', -1, '', statData)
1071
1072
            return 0
1073
          }
1074
          const Postprocess_t = ((performance.now() - startTime) / 1000).toFixed(4)
1075
          console.log(
1076
            'Pre-model processing the whole brain volume in tfjs tooks for multi-class output mask : ',
1077
            ((performance.now() - inferenceStartTime) / 1000).toFixed(4) + '  Seconds'
1078
          )
1079
1080
          // -- Timing data to collect
1081
          statData.Inference_t = Inference_t
1082
          statData.Postprocess_t = Postprocess_t
1083
          statData.Status = 'OK'
1084
1085
          callbackUI('', -1, '', statData)
1086
1087
          if (slices_3d_mask == null) {
1088
            const msg = 'slice_3d_mask failed ...'
1089
            callbackUI(msg, -1, msg)
1090
            return 0
1091
          } else {
1092
            // --Phase-2, After remove the skull try to allocate brain volume and make inferece
1093
            console.log('--- pre-model done ---')
1094
            // --mask_3d = slices_3d_mask.greater([0]).asType('bool')
1095
            // --slices_3d_mask.dispose()
1096
1097
            if (isModelFullVol) {
1098
              if (modelEntry.enableSeqConv) {
1099
                // Mask cropping & seq conv
1100
                // Non-Atlas model (e.g. GWM) needs sequential convolution layer.
1101
                // Sequential convolution layer to be used after cropping - slow but reliable on most machines
1102
                console.log('------ Mask Cropping & Seq Convoluton ------')
1103
                await inferenceFullVolumeSeqCovLayerPhase2(
1104
                  opts,
1105
                  modelEntry,
1106
                  model,
1107
                  slices_3d,
1108
                  num_of_slices,
1109
                  slice_height,
1110
                  slice_width,
1111
                  slices_3d_mask,
1112
                  callbackUI,
1113
                  callbackImg,
1114
                  statData,
1115
                  niftiImage
1116
                )
1117
                return 0
1118
                // inferenceFullVolumeSeqCovLayerPhase2(model, slices_3d.transpose(), num_of_slices, slice_height, slice_width, slices_3d_mask)
1119
              } else {
1120
                // Mask cropping BUT no seq conv
1121
                console.log('------ Mask Cropping  -  NO Seq Convoluton ------')
1122
                await inferenceFullVolumePhase2(
1123
                  model,
1124
                  slices_3d,
1125
                  num_of_slices,
1126
                  slice_height,
1127
                  slice_width,
1128
                  slices_3d_mask,
1129
                  modelEntry,
1130
                  statData,
1131
                  opts,
1132
                  callbackImg,
1133
                  callbackUI,
1134
                  niftiImage
1135
                )
1136
                // inferenceFullVolumePhase2(model, slices_3d.transpose(), num_of_slices, slice_height, slice_width, slices_3d_mask)
1137
              }
1138
            } else {
1139
              // -- In version 3.0.0 this function not used
1140
              window.alert('inferenceSubVolumes() is not dead code?')
1141
            }
1142
          }
1143
        }
1144
        i++
1145
      }, delay)
1146
    } catch (err) {
1147
      callbackUI(err.message, -1, err.message)
1148
      console.log(
1149
        'If webgl context is lost, try to restore webgl context by visit the link ' +
1150
          '<a href="https://support.biodigital.com/hc/en-us/articles/218322977-How-to-turn-on-WebGL-in-my-browser">here</a>'
1151
      )
1152
1153
      // document.getElementById("webGl2Status").style.backgroundColor =  isWebGL2ContextLost() ? "Red" : "Green"
1154
1155
      // document.getElementById("memoryStatus").style.backgroundColor =  tf.memory().unreliable ? "Red" : "Green"
1156
    }
1157
    // })
1158
1159
    // -- if(...)  end
1160
  } else {
1161
    // No preModel
1162
1163
    // --Phase-2, After remove the skull try to allocate brain volume and make inferece
1164
    console.log('--- No pre-model is selected ---')
1165
    console.log('------ Run voxel cropping ------')
1166
    // -- mask_3d = slices_3d.greater([0]).asType('bool')
1167
1168
    if (isModelFullVol) {
1169
      if (modelEntry.enableSeqConv) {
1170
        // Voxel cropping & seq conv
1171
        // Non-Atlas model (e.g. GWM) needs sequential convolution layer.
1172
        // Sequential convolution layer to be used after cropping - slow but reliable on most machines
1173
        console.log('------ Seq Convoluton ------')
1174
        await inferenceFullVolumeSeqCovLayerPhase2(
1175
          opts,
1176
          modelEntry,
1177
          model,
1178
          slices_3d,
1179
          num_of_slices,
1180
          slice_height,
1181
          slice_width,
1182
          null,
1183
          callbackUI,
1184
          callbackImg,
1185
          statData,
1186
          niftiImage
1187
        )
1188
      } else {
1189
        // Voxel cropping BUT no seq conv
1190
        // todo: we do not use result const outimg = await
1191
        inferenceFullVolumePhase2(
1192
          model,
1193
          slices_3d,
1194
          num_of_slices,
1195
          slice_height,
1196
          slice_width,
1197
          null,
1198
          modelEntry,
1199
          statData,
1200
          opts,
1201
          callbackImg,
1202
          callbackUI,
1203
          niftiImage
1204
        )
1205
      }
1206
    } else {
1207
      // -- In version 3.0.0 this function not used
1208
      window.alert('inferenceSubVolumes() is not dead code?')
1209
    }
1210
  }
1211
}
1212
1213
async function enableProductionMode(textureF16Flag = true) {
1214
  // -- tf.setBackend('cpu')
1215
  // -- tf.removeBackend('cpu')
1216
  // -- Calling enableProdMode() method
1217
  await tf.enableProdMode()
1218
  // -- Setting debug mode of the  environment
1219
  tf.env().set('DEBUG', false)
1220
  tf.env().set('WEBGL_FORCE_F16_TEXTURES', textureF16Flag)
1221
  // -- set this flag so that textures are deleted when tensors are disposed.
1222
  tf.env().set('WEBGL_DELETE_TEXTURE_THRESHOLD', -1)
1223
  // -- tf.env().set('WEBGL_PACK', false)
1224
  // -- Put ready after sets above
1225
  await tf.ready()
1226
  // -- Printing output
1227
  console.log('tf env() flags :', tf.env().flags)
1228
  console.log('tf env() features :', tf.env().features)
1229
  console.log('tf env total features: ', Object.keys(tf.env().features).length)
1230
  console.log(tf.getBackend())
1231
}
1232
1233
export async function runInference(opts, modelEntry, niftiHeader, niftiImage, callbackImg, callbackUI) {
1234
  const statData = []
1235
  statData.startTime = Date.now() // for common webworker/mainthread do not use performance.now()
1236
  callbackUI('Segmentation started', 0)
1237
  const startTime = performance.now()
1238
  const batchSize = opts.batchSize
1239
  const numOfChan = opts.numOfChan
1240
  if (isNaN(batchSize) || batchSize !== 1) {
1241
    const errTxt = 'The batch Size for input shape must be 1'
1242
    callbackUI(errTxt, -1, errTxt)
1243
    return 0
1244
  }
1245
  if (isNaN(numOfChan) || numOfChan !== 1) {
1246
    const errTxt = 'The number of channels for input shape must be 1'
1247
    callbackUI(errTxt, -1, errTxt)
1248
    return 0
1249
  }
1250
  tf.engine().startScope()
1251
  console.log('Batch size: ', batchSize)
1252
  console.log('Num of Channels: ', numOfChan)
1253
  const model = await load_model(opts.rootURL + modelEntry.path)
1254
  await enableProductionMode(true)
1255
  statData.TF_Backend = tf.getBackend()
1256
  const modelObject = model
1257
  let batchInputShape = []
1258
  // free global variable of 16777216 voxel
1259
  // allOutputSlices3DCC1DimArray = []
1260
  // outputSceneRendered = false
1261
  // read input shape from model.json object
1262
  batchInputShape = modelObject.layers[0].batchInputShape
1263
  console.log(' Model batch input shape : ', batchInputShape)
1264
  // -- Verify input shape
1265
  if (batchInputShape.length !== 5) {
1266
    const errTxt = 'The model input shape must be 5D'
1267
    callbackUI(errTxt, -1, errTxt)
1268
    return 0
1269
  }
1270
  let batch_D, batch_H, batch_W
1271
  const slice_width = niftiHeader.dims[1]
1272
  const slice_height = niftiHeader.dims[2]
1273
  const num_of_slices = niftiHeader.dims[3]
1274
  const isChannelLast = await isModelChnlLast(modelObject)
1275
  if (isChannelLast) {
1276
    console.log('Model Channel Last')
1277
    if (isNaN(batchInputShape[4]) || batchInputShape[4] !== 1) {
1278
      const errTxt = 'The number of channels for input shape must be 1'
1279
      callbackUI(errTxt, -1, errTxt)
1280
      return 0
1281
    }
1282
    batch_D = batchInputShape[1]
1283
    batch_H = batchInputShape[2]
1284
    batch_W = batchInputShape[3]
1285
  } else {
1286
    console.log('Model Channel First')
1287
    if (isNaN(batchInputShape[1]) || batchInputShape[1] !== 1) {
1288
      const errTxt = 'The number of channels for input shape must be 1'
1289
      callbackUI(errTxt, -1, errTxt)
1290
      return 0
1291
    }
1292
    batch_D = batchInputShape[2]
1293
    batch_H = batchInputShape[3]
1294
    batch_W = batchInputShape[4]
1295
  }
1296
  // const input_shape = [batchSize, numOfChan, batch_D, batch_H, batch_W]
1297
  //  //-- Atlas version check
1298
  // if ( (batch_D > 30) && (batch_H == 256) && (batch_W == 256) ) {
1299
  //    const errTxt = "The subvolume dimension in z-axis shouldn't exceed 30 number of slices for browser limitation"
1300
  //    callbackUI(errTxt, -1, errTxt)
1301
  //    return 0
1302
  // }
1303
  // --Check whether the model will make inference at once as FullVolumeModel
1304
  let isModelFullVol
1305
  if (batch_D === 256 && batch_H === 256 && batch_W === 256) {
1306
    isModelFullVol = true
1307
  } else {
1308
    isModelFullVol = false
1309
  }
1310
  statData.isModelFullVol = isModelFullVol
1311
  // Model output number of segmentations
1312
  let slices_3d = await getAllSlicesDataAsTF3D(num_of_slices, niftiHeader, niftiImage)
1313
  const transpose = modelEntry.enableTranspose
1314
  const enableCrop = modelEntry.enableCrop
1315
  if (isModelFullVol) {
1316
    if (enableCrop) {
1317
      // FullVolume with Crop option before inference ..
1318
      // pre-model to mask the volume, can also be null and the cropping will be on the MRI.
1319
      await inferenceFullVolumePhase1(
1320
        model,
1321
        slices_3d,
1322
        num_of_slices,
1323
        slice_height,
1324
        slice_width,
1325
        isModelFullVol,
1326
        modelEntry,
1327
        statData,
1328
        opts,
1329
        callbackImg,
1330
        callbackUI,
1331
        niftiImage
1332
      )
1333
    } else {
1334
      // Transpose MRI data to be match pytorch/keras input output
1335
      console.log('Cropping Disabled')
1336
1337
      if (transpose) {
1338
        slices_3d = slices_3d.transpose()
1339
        console.log('Input transposed')
1340
      } else {
1341
        console.log('Transpose NOT Enabled')
1342
      }
1343
1344
      const enableSeqConv = modelEntry.enableSeqConv
1345
1346
      if (enableSeqConv) {
1347
        console.log('Seq Convoluton Enabled')
1348
        window.alert('inferenceFullVolumeSeqCovLayer() is not dead code?')
1349
      } else {
1350
        console.log('Seq Convoluton Disabled')
1351
        window.alert('inferenceFullVolume() is not dead code?')
1352
      }
1353
    }
1354
  }
1355
}