// brainchop-webworker.js
import * as tf from '@tensorflow/tfjs'
import { inferenceModelsList } from './brainchop-parameters.js'
import {
  addZeroPaddingTo3dTensor,
  applyMriThreshold,
  binarizeVolumeDataTensor,
  convByOutputChannelAndInputSlicing,
  draw3dObjBoundingVolume,
  firstLastNonZero3D,
  generateBrainMask,
  generateOutputSlicesV2,
  getAllSlicesDataAsTF3D,
  getModelNumLayers,
  getModelNumParameters,
  isModelChnlLast,
  load_model,
  minMaxNormalizeVolumeData,
  quantileNormalizeVolumeData,
  removeZeroPaddingFrom3dTensor,
  resizeWithZeroPadding,
  SequentialConvLayer
} from './tensor-utils.js'

function callbackUI(message = '', progressFrac = -1, modalMessage = '', statData = {}) {
  // Serialize the collected statistics (if any) so they survive postMessage
  let statStr = ''
  if (Object.keys(statData).length > 0) {
    statStr = JSON.stringify({ ...statData })
  }
  self.postMessage({
    cmd: 'ui',
    message,
    progressFrac,
    modalMessage,
    statData: statStr
  })
}

function callbackImg(img, opts, modelEntry) {
  self.postMessage({ cmd: 'img', img, opts, modelEntry })
}

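// Illustrative sketch (not part of this worker): a main thread consuming these
// messages might look roughly like the following; the handler bodies are
// hypothetical and depend on the host application.
//
//   const worker = new Worker('./brainchop-webworker.js', { type: 'module' })
//   worker.addEventListener('message', (event) => {
//     if (event.data.cmd === 'ui') {
//       // update progress bar / status text / modal from event.data
//     } else if (event.data.cmd === 'img') {
//       // render the returned segmentation (event.data.img) with opts/modelEntry
//     }
//   })
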
async function inferenceFullVolumeSeqCovLayerPhase2(
  opts,
  modelEntry,
  model,
  slices_3d,
  num_of_slices,
  slice_height,
  slice_width,
  pipeline1_out,
  statData,
  niftiImage
) {
  // -- Phase-2: after removing the skull, locate the brain volume and run inference

  console.log(' ---- Start FullVolume Inference with Sequential Conv Layer for phase-II ---- ')
  const quantileNorm = modelEntry.enableQuantileNorm
  if (quantileNorm) {
    // Quantile normalization requires models trained with it
    console.log('preModel Quantile normalization enabled')
    slices_3d = await quantileNormalizeVolumeData(slices_3d)
  } else {
    // Min-max normalize the MRI data to the range [0, 1]
    console.log('preModel Min Max normalization enabled')
    slices_3d = await minMaxNormalizeVolumeData(slices_3d)
  }

  let mask_3d

  if (pipeline1_out == null) {
    // preModel is null

    // Check whether thresholding the MRI to remove noisy voxels is needed for better cropping.
    const autoThresholdValue = modelEntry.autoThreshold

    if (autoThresholdValue > 0 && autoThresholdValue <= 1) {
      // Filter out noisy voxels below autoThresholdValue
      mask_3d = await applyMriThreshold(slices_3d, autoThresholdValue)
    } else {
      console.log('No valid crop threshold value')
      // binarize the original image
      mask_3d = slices_3d.greater([0]).asType('bool')
    }
  } else {
    mask_3d = pipeline1_out.greater([0]).asType('bool')
    // -- pipeline1_out.dispose()
  }

  console.log(' mask_3d shape :  ', mask_3d.shape)
  const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
  mask_3d.dispose()
  // -- Reference voxel: where the cropped volume starts in the original grid
  const refVoxel = [row_min, col_min, depth_min]
  // -- Size of the bounding volume, starting from refVoxel
  const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]

  // -- Extract the 3d object (e.g. brain)
  const cropped_slices_3d = slices_3d.slice(
    [row_min, col_min, depth_min],
    [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]
  )
  slices_3d.dispose()

  // -- Padding size to add around the cropped brain
  const pad = modelEntry.cropPadding

  // Create a margin around the bounding volume
  let cropped_slices_3d_w_pad = await addZeroPaddingTo3dTensor(cropped_slices_3d, [pad, pad], [pad, pad], [pad, pad])
  console.log(' cropped slices_3d with padding shape:  ', cropped_slices_3d_w_pad.shape)

  cropped_slices_3d.dispose()

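  // Worked example of the crop/pad bookkeeping above (illustrative numbers, not
  // taken from a real scan): if firstLastNonZero3D returns rows 40..215,
  // cols 30..225, depths 20..235, then
  //   refVoxel        = [40, 30, 20]
  //   boundVolSizeArr = [176, 196, 216]
  // and with modelEntry.cropPadding = 18 the padded crop has shape
  // [212, 232, 252]. refVoxel and boundVolSizeArr are kept so the label volume
  // can later be placed back into the original grid by resizeWithZeroPadding().
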
  if (opts.drawBoundingVolume) {
    let testVol = await removeZeroPaddingFrom3dTensor(cropped_slices_3d_w_pad, pad, pad, pad)
    console.log(' outLabelVolume without padding shape :  ', testVol.shape)

    testVol = await resizeWithZeroPadding(testVol, num_of_slices, slice_height, slice_width, refVoxel, boundVolSizeArr)
    console.log(' outLabelVolume final shape after resizing :  ', testVol.shape)
    draw3dObjBoundingVolume(tf.unstack(testVol), opts, modelEntry, callbackImg)
    testVol.dispose()

    return 0
  }

  statData.Brainchop_Ver = 'FullVolume'
  const res = await model
  try {
    let startTime = performance.now()
    const inferenceStartTime = performance.now()
    // maxLabelPredicted over the whole brain volume
    let maxLabelPredicted = 0
    const transpose = modelEntry.enableTranspose

    if (transpose) {
      cropped_slices_3d_w_pad = cropped_slices_3d_w_pad.transpose()
      console.log('Input transposed')
    } else {
      console.log('Transpose not enabled')
    }

    let i = 1
    const layersLength = res.layers.length
    console.log('res.layers.length ', layersLength)

    const isChannelLast = isModelChnlLast(res)
    const batchSize = opts.batchSize
    const numOfChan = opts.numOfChan
    let adjusted_input_shape
    // -- Adjust the model input shape to the padded crop
    if (isChannelLast) {
      res.layers[0].batchInputShape[1] = cropped_slices_3d_w_pad.shape[0]
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[1]
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[2]

      adjusted_input_shape = [
        batchSize,
        res.layers[0].batchInputShape[1],
        res.layers[0].batchInputShape[2],
        res.layers[0].batchInputShape[3],
        numOfChan
      ]
    } else {
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[0]
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[1]
      res.layers[0].batchInputShape[4] = cropped_slices_3d_w_pad.shape[2]

      adjusted_input_shape = [
        batchSize,
        numOfChan,
        res.layers[0].batchInputShape[2],
        res.layers[0].batchInputShape[3],
        res.layers[0].batchInputShape[4]
      ]
    }

    console.log(' Model batch input shape : ', res.layers[0].batchInputShape)
    // -- batchInputShape {Array} input_shape - e.g. [?, D, H, W, Ch] or [?, Ch, D, H, W]

    statData.Input_Shape = JSON.stringify(res.layers[0].batchInputShape)
    statData.Output_Shape = JSON.stringify(res.output.shape)
    statData.Channel_Last = isChannelLast
    statData.Model_Param = await getModelNumParameters(res)
    statData.Model_Layers = await getModelNumLayers(res)
    statData.Model = modelEntry.modelName
    statData.Seq_Conv = modelEntry.enableSeqConv
    // statData.Extra_Info = null

    // Determine the number of output channels in the last layer of the model,
    // e.g. 3, 50, 104
    const outputLayer = res.layers[res.layers.length - 1]
    console.log('Output Layer : ', outputLayer)

    const expected_Num_labels = isChannelLast
      ? outputLayer.outputShape[outputLayer.outputShape.length - 1]
      : outputLayer.outputShape[1]
    console.log('Num of output channels : ', expected_Num_labels)

    const curTensor = []
    curTensor[0] = cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
    while (true) {
      try {
        if (res.layers[i].activation.getClassName() !== 'linear') {
          curTensor[i] = res.layers[i].apply(curTensor[i - 1])
        } else {
          curTensor[i] = await convByOutputChannelAndInputSlicing(
            curTensor[i - 1],
            res.layers[i].getWeights()[0],
            res.layers[i].getWeights()[1],
            res.layers[i].strides,
            res.layers[i].padding,
            res.layers[i].dilationRate,
            3
          ) // important for memory use
        }

        tf.dispose(curTensor[i - 1])
      } catch (err) {
        const errTxt = 'Your graphics card (e.g. Intel) may not be compatible with WebGL. ' + err.message
        callbackUI(errTxt, -1, errTxt)

        tf.engine().endScope()
        tf.engine().disposeVariables()

        statData.Inference_t = Infinity
        statData.Postprocess_t = Infinity
        statData.Status = 'Fail'
        statData.Error_Type = err.message
        statData.Extra_Err_Info = 'Failed while applying model layer ' + i

        callbackUI('', -1, '', statData)

        return 0
      }

      console.log('layer output Tensor shape : ', curTensor[i].shape)
      console.log('layer count params ', res.layers[i].countParams())

      res.layers[i].dispose()
      curTensor[i - 1].dispose()

      callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
      if (tf.memory().unreliable) {
        const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
        callbackUI(unreliableReasons, NaN, unreliableReasons)
      }
      if (i === layersLength - 2) {
        // Stop before the last layer, i.e. the classification layer.

        // Create an instance of SequentialConvLayer.
        // The second parameter is important for memory:
        // the larger it is, the more memory it uses.
        // it was 8, but I set it to 3, got a different error
        // let seqConvLayer = new SequentialConvLayer(res, 10, isChannelLast)
        const seqConvLayer = new SequentialConvLayer(res, 10, isChannelLast, callbackUI)

        // Apply the last output tensor to the seq. conv instance
        let outputTensor = null
        const profileInfo = await tf.profile(async () => {
          outputTensor = await seqConvLayer.apply(curTensor[i])
        })
        console.log('profileInfo : ', profileInfo)

        // -- document.getElementById("progressBarChild").style.width = 0 + "%";

        // Dispose the previous layer's input tensor
        tf.dispose(curTensor[i])
        // delete the used class instance
        // ? delete seqConvLayer

        // 'outputTensor' now holds the per-voxel labels
        console.log(' Output tensor', outputTensor)
        console.log(' Output tensor shape : ', outputTensor.shape)
        // Array(3) [ 256, 256, 256 ]

        if (outputTensor.shape.length !== 3) {
          const msg = 'Output tensor shape should be 3 dims but it is ' + outputTensor.shape.length
          callbackUI(msg, -1, msg)
        }

        console.log(' find array max ')
        const curBatchMaxLabel = outputTensor.max().dataSync()[0]
        const Inference_t = ((performance.now() - startTime) / 1000).toFixed(4)

        if (maxLabelPredicted < curBatchMaxLabel) {
          maxLabelPredicted = curBatchMaxLabel
        }

        const numSegClasses = maxLabelPredicted + 1
        console.log('Predicted num of segmentation classes', numSegClasses)
        statData.Actual_Labels = numSegClasses
        statData.Expect_Labels = expected_Num_labels
        statData.NumLabels_Match = numSegClasses === expected_Num_labels
        if (numSegClasses !== expected_Num_labels) {
          const msg = 'expected ' + expected_Num_labels + ' labels, but ' + numSegClasses + ' were predicted'
          callbackUI(msg, -1, msg)
        }

        // -- Reshape to the padded crop size before undoing the transforms
        let outLabelVolume = outputTensor.reshape([
          cropped_slices_3d_w_pad.shape[0],
          cropped_slices_3d_w_pad.shape[1],
          cropped_slices_3d_w_pad.shape[2]
        ])
        tf.dispose(outputTensor)

        // Transpose the label volume back to match the PyTorch/Keras input/output
        if (transpose) {
          console.log('outLabelVolume transposed')
          outLabelVolume = outLabelVolume.transpose()
        }

        outLabelVolume = await removeZeroPaddingFrom3dTensor(outLabelVolume, pad, pad, pad)
        console.log(' outLabelVolume without padding shape :  ', outLabelVolume.shape)
        outLabelVolume = await resizeWithZeroPadding(
          outLabelVolume,
          num_of_slices,
          slice_height,
          slice_width,
          refVoxel,
          boundVolSizeArr
        )
        console.log(' outLabelVolume final shape after resizing :  ', outLabelVolume.shape)

        // let filterOutWithPreMask =  inferenceModelsList[$$("selectModel").getValue() - 1]["filterOutWithPreMask"]
        const filterOutWithPreMask = modelEntry.filterOutWithPreMask
        // Clean the skull area wrongly segmented in phase-2.
        if (pipeline1_out != null && opts.isBrainCropMaskBased && filterOutWithPreMask) {
          const bin = await binarizeVolumeDataTensor(pipeline1_out)
          outLabelVolume = outLabelVolume.mul(bin)
        }

        startTime = performance.now()
        // Generate output volume or slices
        console.log('Generating correct output')
        let outimg
        try {
          const img = new Uint32Array(outLabelVolume.dataSync())
          const Vshape = outLabelVolume.shape
          const Vtype = outLabelVolume.dtype
          outimg = await generateOutputSlicesV2(
            img,
            Vshape,
            Vtype,
            num_of_slices,
            numSegClasses,
            slice_height,
            slice_width,
            modelEntry,
            opts,
            niftiImage
          )
          console.log(' Phase-2 num of tensors after generateOutputSlicesV2: ', tf.memory().numTensors)

          tf.dispose(outLabelVolume)
          tf.engine().endScope()
          tf.engine().disposeVariables()
        } catch (error) {
          // -- Timing data to collect
          tf.engine().endScope()
          tf.engine().disposeVariables()
          console.log('Error while generating output: ', error)
          const msg = 'Failed while generating output due to limited browser memory available'
          callbackUI(msg, -1, msg)

          statData.Inference_t = Inference_t
          statData.Postprocess_t = Infinity
          statData.Status = 'Fail'
          statData.Error_Type = error.message
          statData.Extra_Err_Info = 'Failed while generating output'

          callbackUI('', -1, '', statData)

          return 0
        }
        const Postprocess_t = ((performance.now() - startTime) / 1000).toFixed(4)

        console.log(
          'Processing the whole brain volume in tfjs for multi-class output mask took : ',
          ((performance.now() - inferenceStartTime) / 1000).toFixed(4) + '  Seconds'
        )

        // -- Timing data to collect
        statData.Inference_t = Inference_t
        statData.Postprocess_t = Postprocess_t
        statData.Status = 'OK'

        callbackUI('', -1, '', statData)
        callbackUI('Segmentation finished', 0)
        callbackImg(outimg, opts, modelEntry)
        return 0
      } else {
        i++
      }
    }
  } catch (err) {
    callbackUI(err.message, -1, err.message)
    console.log(
      'If the WebGL context is lost, try to restore it by visiting the link ' +
        '<a href="https://support.biodigital.com/hc/en-us/articles/218322977-How-to-turn-on-WebGL-in-my-browser">here</a>'
    )
    if (tf.memory().unreliable) {
      const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
      callbackUI(unreliableReasons, NaN, unreliableReasons)
    }
  }
}

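// Minimal sketch of the "sequential conv" idea used above (assumed semantics of
// SequentialConvLayer / convByOutputChannelAndInputSlicing, not their actual
// implementation): instead of one conv3d producing all output channels at once,
// each output channel is computed separately and the results merged, which
// bounds peak WebGL memory at the cost of speed.
//
//   function seqConv3dSketch(input, kernel, bias) {
//     // kernel shape: [kd, kh, kw, inChannels, outChannels]
//     const outChannels = kernel.shape[4]
//     const parts = []
//     for (let c = 0; c < outChannels; c++) {
//       const k = kernel.slice([0, 0, 0, 0, c], [-1, -1, -1, -1, 1])
//       const b = bias.slice([c], [1])
//       parts.push(tf.conv3d(input, k, 1, 'same').add(b))
//     }
//     return tf.concat(parts, -1)
//   }
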
async function inferenceFullVolumePhase2(
  model,
  slices_3d,
  num_of_slices,
  slice_height,
  slice_width,
  pipeline1_out,
  modelEntry,
  statData,
  opts,
  niftiImage
) {
  let outimg = []
  // -- Phase-2: after removing the skull, locate the brain volume and run inference
  console.log(' ---- Start FullVolume inference phase-II ---- ')
  const quantileNorm = modelEntry.enableQuantileNorm
  if (quantileNorm) {
    // Quantile normalization requires models trained with it
    console.log('preModel Quantile normalization enabled')
    slices_3d = await quantileNormalizeVolumeData(slices_3d)
  } else {
    // Min-max normalize the MRI data to the range [0, 1]
    console.log('preModel Min Max normalization enabled')
    slices_3d = await minMaxNormalizeVolumeData(slices_3d)
  }
  let mask_3d
  if (pipeline1_out == null) {
    // preModel is null

    // Check whether thresholding the MRI to remove noisy voxels is needed for better cropping.
    const autoThresholdValue = modelEntry.autoThreshold

    if (autoThresholdValue > 0 && autoThresholdValue <= 1) {
      // Filter out noisy voxels below autoThresholdValue
      mask_3d = await applyMriThreshold(slices_3d, autoThresholdValue)
    } else {
      console.log('No valid crop threshold value')
      // binarize the original image
      mask_3d = slices_3d.greater([0]).asType('bool')
    }
  } else {
    mask_3d = pipeline1_out.greater([0]).asType('bool')
    // -- pipeline1_out.dispose()
  }
  console.log(' mask_3d shape :  ', mask_3d.shape)
  const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
  mask_3d.dispose()
  // -- Reference voxel: where the cropped volume starts in the original grid
  const refVoxel = [row_min, col_min, depth_min]
  console.log('refVoxel :', refVoxel)

  // -- Size of the bounding volume, starting from refVoxel
  const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]

  console.log('boundVolSizeArr :', boundVolSizeArr)
  // -- Extract the 3d object (e.g. brain)
  const cropped_slices_3d = slices_3d.slice(
    [row_min, col_min, depth_min],
    [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]
  )

  slices_3d.dispose()

  // -- Padding size to add around the cropped brain
  const pad = modelEntry.cropPadding

  // Create a margin around the bounding volume
  let cropped_slices_3d_w_pad = await addZeroPaddingTo3dTensor(cropped_slices_3d, [pad, pad], [pad, pad], [pad, pad])
  console.log(' cropped slices_3d with padding shape:  ', cropped_slices_3d_w_pad.shape)

  cropped_slices_3d.dispose()

  // -- Test dims after padding ..
  // for (let i = 0; i < cropped_slices_3d_w_pad.rank; i++) {
  //     if(cropped_slices_3d_w_pad.shape[i] > 256) {
  //          console.log(" cropped_slices_3d_w_pad > 256 ")
  //     }
  // }

  if (opts.drawBoundingVolume) {
    let testVol = await removeZeroPaddingFrom3dTensor(cropped_slices_3d_w_pad, pad, pad, pad)
    console.log(' outLabelVolume without padding shape :  ', testVol.shape)

    testVol = await resizeWithZeroPadding(testVol, num_of_slices, slice_height, slice_width, refVoxel, boundVolSizeArr)
    console.log(' outLabelVolume final shape after resizing :  ', testVol.shape)
    draw3dObjBoundingVolume(tf.unstack(testVol), opts, modelEntry, callbackImg)
    testVol.dispose()

    return 0
  }

  statData.Brainchop_Ver = 'FullVolume'
  let startTime = performance.now()
  let adjusted_input_shape = []
  const res = await model
  try {
    startTime = performance.now()
    const inferenceStartTime = performance.now()
    // maxLabelPredicted over the whole brain volume
    let maxLabelPredicted = 0
    const transpose = modelEntry.enableTranspose

    if (transpose) {
      cropped_slices_3d_w_pad = cropped_slices_3d_w_pad.transpose()
      console.log('Input transposed')
    } else {
      console.log('Transpose not enabled')
    }

    let i = 1
    const layersLength = res.layers.length
    console.log('res.layers.length ', layersLength)

    const isChannelLast = isModelChnlLast(res)
    const batchSize = opts.batchSize
    const numOfChan = opts.numOfChan

    // -- Adjust the model input shape to the padded crop
    if (isChannelLast) {
      res.layers[0].batchInputShape[1] = cropped_slices_3d_w_pad.shape[0]
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[1]
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[2]

      adjusted_input_shape = [
        batchSize,
        res.layers[0].batchInputShape[1],
        res.layers[0].batchInputShape[2],
        res.layers[0].batchInputShape[3],
        numOfChan
      ]
    } else {
      res.layers[0].batchInputShape[2] = cropped_slices_3d_w_pad.shape[0]
      res.layers[0].batchInputShape[3] = cropped_slices_3d_w_pad.shape[1]
      res.layers[0].batchInputShape[4] = cropped_slices_3d_w_pad.shape[2]

      adjusted_input_shape = [
        batchSize,
        numOfChan,
        res.layers[0].batchInputShape[2],
        res.layers[0].batchInputShape[3],
        res.layers[0].batchInputShape[4]
      ]
    }

    console.log(' Model batch input shape : ', res.layers[0].batchInputShape)
    // -- batchInputShape {Array} input_shape - e.g. [?, D, H, W, Ch] or [?, Ch, D, H, W]

    statData.Input_Shape = JSON.stringify(res.layers[0].batchInputShape)
    statData.Output_Shape = JSON.stringify(res.output.shape)
    statData.Channel_Last = isChannelLast
    statData.Model_Param = await getModelNumParameters(res)
    statData.Model_Layers = await getModelNumLayers(res)
    statData.Model = modelEntry.modelName
    // statData.Extra_Info = null

    const curTensor = []
    curTensor[0] = cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
    // console.log("curTensor[0] :", curTensor[0].dataSync())

    while (true) {
      try {
        // -- curTensor[i] = res.layers[i].apply( curTensor[i-1])
        curTensor[i] = res.layers[i].apply(curTensor[i - 1])
      } catch (err) {
        callbackUI(err.message, -1, err.message)
        tf.engine().endScope()
        tf.engine().disposeVariables()

        statData.Inference_t = Infinity
        statData.Postprocess_t = Infinity
        statData.Status = 'Fail'
        statData.Error_Type = err.message
        statData.Extra_Err_Info = 'Failed while applying model layer ' + i

        callbackUI('', -1, '', statData)

        return 0
      }
      callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
      console.log('layer output Tensor shape : ', curTensor[i].shape)
      console.log('layer count params ', res.layers[i].countParams())
      res.layers[i].dispose()
      curTensor[i - 1].dispose()
      if (tf.memory().unreliable) {
        const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
        callbackUI(unreliableReasons, NaN, unreliableReasons)
      }

      if (i === layersLength - 1) {
        // prediction = res.layers[res.layers.length-1].apply(curTensor[i])
        // curTensor[i].print()
        // outputDataBeforArgmx = Array.from(curTensor[i].dataSync())

        const axis = isChannelLast ? -1 : 1
        console.log(' find argmax ')
        console.log('last Tensor shape : ', curTensor[i].shape)
        // -- curTensor[i].shape  e.g. [ 1, 256, 256, 256, 3 ]
        const expected_Num_labels = isChannelLast ? curTensor[i].shape[4] : curTensor[i].shape[1]
        let prediction_argmax

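        // e.g. for a channel-last output of shape [1, D, H, W, 3], tf.argMax
        // along axis -1 collapses the 3 class scores per voxel into a single
        // label in [0, 2], giving a [1, D, H, W] tensor.
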
        // Try argMax on the model output tensor.

        try {
          const argMaxTime = performance.now()
          console.log(' Try tf.argMax for fullVolume ..')
          prediction_argmax = tf.argMax(curTensor[i], axis)
          console.log('tf.argMax for fullVolume takes : ', ((performance.now() - argMaxTime) / 1000).toFixed(4))
        } catch (err1) {
          // if channel last
          if (axis === -1) {
            try {
              const argMaxLargeTime = performance.now()
              console.log(' tf.argMax failed .. try argMaxLarge ..')
              callbackUI('', -1, 'tensor2LightBuffer() is not dead code?')
              callbackUI('', -1, 'argMaxLarge() is not dead code?')
              console.log(
                'argMaxLarge for fullVolume takes : ',
                ((performance.now() - argMaxLargeTime) / 1000).toFixed(4)
              )
            } catch (err2) {
              const errTxt = "argMax buffer couldn't be created due to limited memory resources."
              callbackUI(errTxt, -1, errTxt)

              tf.engine().endScope()
              tf.engine().disposeVariables()

              statData.Inference_t = Infinity
              statData.Postprocess_t = Infinity
              statData.Status = 'Fail'
              statData.Error_Type = err2.message
              statData.Extra_Err_Info = 'prediction_argmax from argMaxLarge failed'

              callbackUI('', -1, '', statData)
              return 0
            }
          } else {
            // if channel first ..
            const errTxt = "argMax buffer couldn't be created due to limited memory resources."
            callbackUI(errTxt, -1, errTxt)

            tf.engine().endScope()
            tf.engine().disposeVariables()

            statData.Inference_t = Infinity
            statData.Postprocess_t = Infinity
            statData.Status = 'Fail'
            statData.Error_Type = err1.message
            statData.Extra_Err_Info = 'prediction_argmax from argMaxLarge does not yet support channel-first'

            callbackUI('', -1, '', statData)

            return 0
          }
        }

        console.log(' prediction_argmax shape : ', prediction_argmax.shape)
        // -- prediction_argmax.shape  : [ 1, 256, 256, 256]

        const Inference_t = ((performance.now() - startTime) / 1000).toFixed(4)

        // outputDataBeforArgmx = Array.from(prediction_argmax.dataSync())
        tf.dispose(curTensor[i])
        console.log(' find array max ')
        const curBatchMaxLabel = prediction_argmax.max().dataSync()[0]

        if (maxLabelPredicted < curBatchMaxLabel) {
          maxLabelPredicted = curBatchMaxLabel
        }

        const numSegClasses = maxLabelPredicted + 1
        console.log('numSegClasses', numSegClasses)
        statData.Actual_Labels = numSegClasses
        statData.Expect_Labels = expected_Num_labels
        statData.NumLabels_Match = numSegClasses === expected_Num_labels

        if (numSegClasses !== expected_Num_labels) {
          // errTxt = "expected " + expected_Num_labels + " labels, but the predicted are " + numSegClasses + ". For possible solutions please refer to <a href='https://github.com/neuroneural/brainchop/wiki/FAQ#Q3' target='_blank'><b> FAQ </b></a>.", "alert-error"
          const errTxt = 'expected ' + expected_Num_labels + ' labels, but ' + numSegClasses + ' were predicted'
          callbackUI(errTxt, -1, errTxt)
        }

        // -- Reshape to the padded crop size before undoing the transforms
        let outLabelVolume = prediction_argmax.reshape([
          cropped_slices_3d_w_pad.shape[0],
          cropped_slices_3d_w_pad.shape[1],
          cropped_slices_3d_w_pad.shape[2]
        ])
        tf.dispose(prediction_argmax)

        // Transpose the label volume back to match the PyTorch/Keras input/output
        if (transpose) {
          console.log('outLabelVolume transposed')
          outLabelVolume = outLabelVolume.transpose()
        }
        outLabelVolume = await removeZeroPaddingFrom3dTensor(outLabelVolume, pad, pad, pad)
        console.log(' outLabelVolume without padding shape :  ', outLabelVolume.shape)
        outLabelVolume = await resizeWithZeroPadding(
          outLabelVolume,
          num_of_slices,
          slice_height,
          slice_width,
          refVoxel,
          boundVolSizeArr
        )
        console.log(' outLabelVolume final shape after resizing :  ', outLabelVolume.shape)

        const filterOutWithPreMask = modelEntry.filterOutWithPreMask
        // Clean the skull area wrongly segmented in phase-2.
        if (pipeline1_out != null && opts.isBrainCropMaskBased && filterOutWithPreMask) {
          const bin = await binarizeVolumeDataTensor(pipeline1_out)
          outLabelVolume = outLabelVolume.mul(bin)
        }

        startTime = performance.now()
        // Generate output volume or slices
        console.log('Generating correct output')

        try {
          const img = new Uint32Array(outLabelVolume.dataSync())
          const Vshape = outLabelVolume.shape
          const Vtype = outLabelVolume.dtype
          tf.dispose(outLabelVolume)
          tf.engine().endScope()
          tf.engine().disposeVariables()
          outimg = await generateOutputSlicesV2(
            img,
            Vshape,
            Vtype,
            num_of_slices,
            numSegClasses,
            slice_height,
            slice_width,
            modelEntry,
            opts,
            niftiImage
          )
          console.log(' Phase-2 num of tensors after generateOutputSlicesV2: ', tf.memory().numTensors)
        } catch (error) {
          // -- Timing data to collect
          tf.engine().endScope()
          tf.engine().disposeVariables()

          const errTxt = 'Failed while generating output due to limited browser memory available'
          callbackUI(errTxt, -1, errTxt)
          statData.Inference_t = Inference_t
          statData.Postprocess_t = Infinity
          statData.Status = 'Fail'
          statData.Error_Type = error.message
          statData.Extra_Err_Info = 'Failed while generating output'

          callbackUI('', -1, '', statData)

          return 0
        }

        const Postprocess_t = ((performance.now() - startTime) / 1000).toFixed(4)

        tf.engine().disposeVariables()

        console.log(
          'Processing the whole brain volume in tfjs for multi-class output mask took : ',
          ((performance.now() - inferenceStartTime) / 1000).toFixed(4) + '  Seconds'
        )

        // -- Timing data to collect
        statData.Inference_t = Inference_t
        statData.Postprocess_t = Postprocess_t
        statData.Status = 'OK'
        callbackUI('Segmentation finished', 0)
        callbackUI('', -1, '', statData)
        callbackImg(outimg, opts, modelEntry)

        return 0
      }
      i++
    }
  } catch (err) {
    callbackUI(err.message, -1, err.message)
    console.log(
      'If the WebGL context is lost, try to restore it by visiting the link ' +
        '<a href="https://support.biodigital.com/hc/en-us/articles/218322977-How-to-turn-on-WebGL-in-my-browser">here</a>'
    )
  }
}

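// Note: inferenceFullVolumePhase2 above differs from the SeqConv variant in that
// it applies every layer directly (res.layers[i].apply) and takes a single
// tf.argMax over the full logits tensor; this is faster, but it needs enough
// GPU memory to hold all output channels at once, which is why the sequential
// path exists for models with larger label counts.
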
async function inferenceFullVolumePhase1(
  model,
  slices_3d,
  num_of_slices,
  slice_height,
  slice_width,
  isModelFullVol,
  modelEntry,
  statData,
  opts,
  niftiHeader,
  niftiImage
) {
  statData.No_SubVolumes = 1
  // Load the pre-model for inference first; it can be null if there is no pre-model, as with GWM models
  if (modelEntry.preModelId) {
    const preModel = await load_model(opts.rootURL + inferenceModelsList[modelEntry.preModelId - 1].path)
    const transpose = inferenceModelsList[modelEntry.preModelId - 1].enableTranspose
    const quantileNorm = inferenceModelsList[modelEntry.preModelId - 1].enableQuantileNorm
    let preModel_slices_3d = null

    // -- If the pre-model is not null, a slices_3d mask will be generated..
    // -- The mask is needed to remove the skull, set background noise to 0, and find the brain bounding volume properly
    const slices_3d_mask = null

    if (quantileNorm) {
      // Quantile normalization requires models trained with it
      console.log('preModel Quantile normalization enabled')
      preModel_slices_3d = await quantileNormalizeVolumeData(slices_3d)
    } else {
      // Min-max normalize the MRI data to the range [0, 1]
      console.log('preModel Min Max normalization enabled')
      preModel_slices_3d = await minMaxNormalizeVolumeData(slices_3d)
    }

    // -- Transpose MRI data to match the PyTorch/Keras input/output
    // -- Check if the pre-model needs transposing..
    if (transpose) {
      preModel_slices_3d = preModel_slices_3d.transpose()
      console.log('Input transposed for pre-model')
    } else {
      console.log('Transpose not enabled for pre-model')
    }

    statData.Brainchop_Ver = 'PreModel_FV' // e.g. "PreModel_FV"

    // preModel.then(function (res) {
    const res = await preModel

    try {
      const inferenceStartTime = performance.now()
      const preModelObject = res

      // read the input shape from the model.json object
      const preModelBatchInputShape = preModelObject.layers[0].batchInputShape
      console.log(' Pre-Model batch input shape : ', preModelBatchInputShape)

      // -- Verify the input shape
      if (preModelBatchInputShape.length !== 5) {
        const errTxt = 'The pre-model input shape must be 5D '
        callbackUI(errTxt, -1, errTxt)
        return 0
      }

      const isPreModelChannelLast = isModelChnlLast(preModelObject)
      const batchSize = opts.batchSize
      const numOfChan = opts.numOfChan
      let batch_D, batch_H, batch_W
      let preModel_input_shape
      if (isPreModelChannelLast) {
        console.log('Pre-Model Channel Last')
        if (isNaN(preModelBatchInputShape[4]) || preModelBatchInputShape[4] !== 1) {
          const errTxt = 'The number of channels for pre-model input shape must be 1'
          callbackUI(errTxt, -1, errTxt)
          return 0
        }

        batch_D = preModelBatchInputShape[1]
        batch_H = preModelBatchInputShape[2]
        batch_W = preModelBatchInputShape[3]

        preModel_input_shape = [batchSize, batch_D, batch_H, batch_W, numOfChan]
      } else {
        console.log('Pre-Model Channel First')
        if (isNaN(preModelBatchInputShape[1]) || preModelBatchInputShape[1] !== 1) {
          const errTxt = 'The number of channels for pre-model input shape must be 1'
          callbackUI(errTxt, -1, errTxt)
          return 0
        }

        batch_D = preModelBatchInputShape[2]
        batch_H = preModelBatchInputShape[3]
        batch_W = preModelBatchInputShape[4]

        preModel_input_shape = [batchSize, numOfChan, batch_D, batch_H, batch_W]
      }

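      // e.g. a channel-last pre-model with batchInputShape [null, 256, 256, 256, 1]
      // yields preModel_input_shape = [1, 256, 256, 256, 1] here.
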
      statData.Input_Shape = JSON.stringify(preModel_input_shape)
      statData.Output_Shape = JSON.stringify(preModelObject.output.shape)
      statData.Channel_Last = isPreModelChannelLast
      statData.Model_Param = await getModelNumParameters(preModelObject)
      statData.Model_Layers = await getModelNumLayers(preModelObject)

      // maxLabelPredicted over the whole brain volume
      let maxLabelPredicted = 0

      let i = 1
      const layersLength = res.layers.length

      const curTensor = []
      // -- reshape the MRI to the model input shape
      curTensor[0] = preModel_slices_3d.reshape(preModel_input_shape)

      // Dispose the volume
      tf.dispose(preModel_slices_3d)
      while (true) {
        try {
          curTensor[i] = res.layers[i].apply(curTensor[i - 1])
        } catch (err) {
          const errTxt = 'Your graphics card (e.g. Intel) may not be compatible with WebGL. ' + err.message
          callbackUI(errTxt, -1, errTxt)

          tf.engine().endScope()
          tf.engine().disposeVariables()

          statData.Inference_t = Infinity
          statData.Postprocess_t = Infinity
          statData.Status = 'Fail'
          statData.Error_Type = err.message
          statData.Extra_Err_Info = 'PreModel failed while applying layer ' + i

          callbackUI('', -1, '', statData)

          return 0
        }

        res.layers[i].dispose()
        curTensor[i - 1].dispose()

        callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
        if (tf.memory().unreliable) {
          const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
          callbackUI(unreliableReasons, NaN, unreliableReasons)
        }

        if (i === layersLength - 1) {
          // -- prediction = res.layers[res.layers.length-1].apply(curTensor[i])
          // -- curTensor[i].print()
          // -- outputDataBeforArgmx = Array.from(curTensor[i].dataSync())

          const axis = isPreModelChannelLast ? -1 : 1
          console.log(' find argmax ')
          console.log('last Tensor shape : ', curTensor[i].shape)
          // -- curTensor[i].shape  : [ 1, 256, 256, 256, 3 ]
          const expected_Num_labels = isPreModelChannelLast ? curTensor[i].shape[4] : curTensor[i].shape[1]
          let prediction_argmax

          // Try argMax on the model output tensor.

          try {
            console.log(' Try tf.argMax for fullVolume ..')
            prediction_argmax = tf.argMax(curTensor[i], axis)
          } catch (err1) {
            // if channel last
            if (axis === -1) {
              try {
                const argMaxLargeTime = performance.now()
                console.log(' tf.argMax failed .. try argMaxLarge ..')
                callbackUI('', -1, 'tensor2LightBuffer() is not dead code?')
                callbackUI('', -1, 'argMaxLarge() is not dead code?')
                console.log(
                  'argMaxLarge for fullVolume takes : ',
                  ((performance.now() - argMaxLargeTime) / 1000).toFixed(4)
                )
              } catch (err2) {
                const errTxt = "argMax buffer couldn't be created due to limited memory resources."
                callbackUI(errTxt, -1, errTxt)

                tf.engine().endScope()
                tf.engine().disposeVariables()

                statData.Inference_t = Infinity
                statData.Postprocess_t = Infinity
                statData.Status = 'Fail'
                statData.Error_Type = err2.message
                statData.Extra_Err_Info = 'preModel prediction_argmax from argMaxLarge failed'

                callbackUI('', -1, '', statData)

                return 0
              }
            } else {
              // if channel first ..
              const errTxt = "argMax buffer couldn't be created due to limited memory resources."
              callbackUI(errTxt, -1, errTxt)

              tf.engine().endScope()
              tf.engine().disposeVariables()

              statData.Inference_t = Infinity
              statData.Postprocess_t = Infinity
              statData.Status = 'Fail'
              statData.Error_Type = err1.message
              statData.Extra_Err_Info = 'preModel prediction_argmax from argMaxLarge does not yet support channel-first'

              callbackUI('', -1, '', statData)

              return 0
            }
          }

          console.log(' Pre-model prediction_argmax shape : ', prediction_argmax.shape)
          // -- prediction_argmax.shape  : [ 1, 256, 256, 256]

          const Inference_t = ((performance.now() - inferenceStartTime) / 1000).toFixed(4)

          tf.dispose(curTensor[i])

          console.log(' Pre-model find array max ')
          const curBatchMaxLabel = prediction_argmax.max().dataSync()[0]

          if (maxLabelPredicted < curBatchMaxLabel) {
            maxLabelPredicted = curBatchMaxLabel
          }

          const numSegClasses = maxLabelPredicted + 1
          console.log('Pre-model numSegClasses', numSegClasses)

          statData.Actual_Labels = numSegClasses
          statData.Expect_Labels = expected_Num_labels
          statData.NumLabels_Match = numSegClasses === expected_Num_labels

          // -- Reshape to the original volume size
          let outLabelVolume = prediction_argmax.reshape([num_of_slices, slice_height, slice_width])
          tf.dispose(prediction_argmax)
          // Transpose the label volume back to match the PyTorch/Keras input/output
          if (transpose) {
            console.log('Pre-model outLabelVolume transposed')
            outLabelVolume = outLabelVolume.transpose()
          }
          const startTime = performance.now()
          // Generate output volume or slices
          console.log('Generating pre-model output')
          let slices_3d_mask
          try {
            const unstackOutVolumeTensor = tf.unstack(outLabelVolume)
            slices_3d_mask = await generateBrainMask(
              unstackOutVolumeTensor,
              num_of_slices,
              slice_height,
              slice_width,
              modelEntry,
              opts,
              niftiHeader,
              niftiImage,
              false
            )
            tf.dispose(outLabelVolume)
            console.log(' Phase-1 num of tensors after generateBrainMask: ', tf.memory().numTensors)
          } catch (error) {
            // -- Timing data to collect
            tf.engine().endScope()
            tf.engine().disposeVariables()

            const errTxt = 'Failed while generating pre-model output due to limited browser memory available'
            callbackUI(errTxt, -1, errTxt)

            statData.Inference_t = Inference_t
            statData.Postprocess_t = Infinity
            statData.Status = 'Fail'
            statData.Error_Type = error.message
            statData.Extra_Err_Info = 'Pre-model failed while generating output'

            callbackUI('', -1, '', statData)

            return 0
          }
          const Postprocess_t = ((performance.now() - startTime) / 1000).toFixed(4)
          console.log(
            'Pre-model processing the whole brain volume in tfjs took for multi-class output mask : ',
            ((performance.now() - inferenceStartTime) / 1000).toFixed(4) + '  Seconds'
          )

          // -- Timing data to collect
          statData.Inference_t = Inference_t
          statData.Postprocess_t = Postprocess_t
          statData.Status = 'OK'

          callbackUI('', -1, '', statData)

          if (slices_3d_mask == null) {
            const msg = 'slices_3d_mask failed ...'
            callbackUI(msg, -1, msg)
            return 0
          } else {
            // -- Phase-2: after removing the skull, locate the brain volume and run inference
            console.log('--- pre-model done ---')
            // --mask_3d = slices_3d_mask.greater([0]).asType('bool')
            // --slices_3d_mask.dispose()

            if (isModelFullVol) {
              if (modelEntry.enableSeqConv) {
                // Mask cropping & seq conv
                // A non-atlas model (e.g. GWM) needs the sequential convolution layer.
                // The sequential convolution layer is used after cropping - slow but reliable on most machines
                console.log('------ Mask Cropping & Seq Convolution ------')
                await inferenceFullVolumeSeqCovLayerPhase2(
                  opts,
                  modelEntry,
                  model,
                  slices_3d,
                  num_of_slices,
                  slice_height,
                  slice_width,
                  slices_3d_mask,
                  statData,
                  niftiImage
                )
                return 0
                // inferenceFullVolumeSeqCovLayerPhase2(model, slices_3d.transpose(), num_of_slices, slice_height, slice_width, slices_3d_mask)
              } else {
                // Mask cropping BUT no seq conv
                console.log('------ Mask Cropping  -  NO Seq Convolution ------')
                await inferenceFullVolumePhase2(
                  model,
                  slices_3d,
                  num_of_slices,
                  slice_height,
                  slice_width,
                  slices_3d_mask,
                  modelEntry,
                  statData,
                  opts,
                  niftiImage
                )
                // inferenceFullVolumePhase2(model, slices_3d.transpose(), num_of_slices, slice_height, slice_width, slices_3d_mask)
              }
            } else {
              // -- As of version 3.0.0 this function is not used
              callbackUI('', -1, 'inferenceSubVolumes() is not dead code?')
            }
          }
        }
        i++
      }
    } catch (err) {
      callbackUI(err.message, -1, err.message)
      console.log(
        'If the WebGL context is lost, try to restore it by visiting the link ' +
          '<a href="https://support.biodigital.com/hc/en-us/articles/218322977-How-to-turn-on-WebGL-in-my-browser">here</a>'
      )

      // document.getElementById("webGl2Status").style.backgroundColor =  isWebGL2ContextLost() ? "Red" : "Green"
      // document.getElementById("memoryStatus").style.backgroundColor =  tf.memory().unreliable ? "Red" : "Green"
    }
    // })

    // -- if(...)  end
  } else {
    // No preModel

    // -- Phase-2: after removing the skull, locate the brain volume and run inference
    console.log('--- No pre-model is selected ---')
    console.log('------ Run voxel cropping ------')
    // -- mask_3d = slices_3d.greater([0]).asType('bool')

    if (isModelFullVol) {
      if (modelEntry.enableSeqConv) {
        // Voxel cropping & seq conv
        // A non-atlas model (e.g. GWM) needs the sequential convolution layer.
        // The sequential convolution layer is used after cropping - slow but reliable on most machines
        console.log('------ Seq Convolution ------')
        await inferenceFullVolumeSeqCovLayerPhase2(
          opts,
          modelEntry,
          model,
          slices_3d,
          num_of_slices,
          slice_height,
          slice_width,
          null,
          statData,
          niftiImage
        )
      } else {
        // Voxel cropping BUT no seq conv
        // todo: we do not use the result; const outimg = await
        inferenceFullVolumePhase2(
          model,
          slices_3d,
          num_of_slices,
          slice_height,
          slice_width,
          null,
          modelEntry,
          statData,
          opts,
          niftiImage
        )
      }
    } else {
      // -- As of version 3.0.0 this function is not used
      callbackUI('', -1, 'inferenceSubVolumes() is not dead code?')
    }
  }
}

async function enableProductionMode(textureF16Flag = true) {
  // -- tf.setBackend('cpu')
  tf.setBackend('webgl')
  // -- tf.removeBackend('cpu')
  // -- Enable production mode
  await tf.enableProdMode()
  // -- Turn off debug mode for the environment
  tf.env().set('DEBUG', false)
  tf.env().set('WEBGL_FORCE_F16_TEXTURES', textureF16Flag)
  // -- Set this flag so that textures are deleted when tensors are disposed.
  tf.env().set('WEBGL_DELETE_TEXTURE_THRESHOLD', -1)
  // -- tf.env().set('WEBGL_PACK', false)
  // -- Wait for readiness after the flags above are set
  await tf.ready()
  // -- Print the resulting environment
  console.log('tf env() flags :', tf.env().flags)
  console.log('tf env() features :', tf.env().features)
  console.log('tf env total features: ', Object.keys(tf.env().features).length)
  console.log('tf backend: ', tf.getBackend())
}

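// Illustrative check (hypothetical helper, not part of this file): before running
// inference a caller could verify that the WebGL backend actually initialized,
// since tf.setBackend can fall back on unsupported hardware.
//
//   async function assertWebglReady() {
//     await tf.ready()
//     if (tf.getBackend() !== 'webgl') {
//       throw new Error('WebGL backend unavailable; got ' + tf.getBackend())
//     }
//   }
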
async function runInferenceWW(opts, modelEntry, niftiHeader, niftiImage) {
  const statData = {}
  statData.startTime = Date.now() // for common webworker/mainthread code, do not use performance.now()
  callbackUI('Segmentation started', 0)
  const batchSize = opts.batchSize
  const numOfChan = opts.numOfChan
  if (isNaN(batchSize) || batchSize !== 1) {
    const errTxt = 'The batch size for the input shape must be 1'
    callbackUI(errTxt, -1, errTxt)
    return 0
  }
  if (isNaN(numOfChan) || numOfChan !== 1) {
    const errTxt = 'The number of channels for the input shape must be 1'
    callbackUI(errTxt, -1, errTxt)
    return 0
  }
  tf.engine().startScope()
  console.log('Batch size: ', batchSize)
  console.log('Num of Channels: ', numOfChan)
  const model = await load_model(opts.rootURL + modelEntry.path)
  await enableProductionMode(true)
  statData.TF_Backend = tf.getBackend()
  const modelObject = model
  let batchInputShape = []
  // free global variable of 16777216 voxels
  // allOutputSlices3DCC1DimArray = []
  // outputSceneRendered = false
  // read the input shape from the model.json object
  batchInputShape = modelObject.layers[0].batchInputShape
  console.log(' Model batch input shape : ', batchInputShape)
  // -- Verify the input shape
  if (batchInputShape.length !== 5) {
    const errTxt = 'The model input shape must be 5D'
    callbackUI(errTxt, -1, errTxt)
    return 0
  }
  let batch_D, batch_H, batch_W
  const slice_width = niftiHeader.dims[1]
  const slice_height = niftiHeader.dims[2]
  const num_of_slices = niftiHeader.dims[3]
  const isChannelLast = isModelChnlLast(modelObject)
  if (isChannelLast) {
    console.log('Model Channel Last')
    if (isNaN(batchInputShape[4]) || batchInputShape[4] !== 1) {
      const errTxt = 'The number of channels for the input shape must be 1'
      callbackUI(errTxt, -1, errTxt)
      return 0
    }
    batch_D = batchInputShape[1]
    batch_H = batchInputShape[2]
    batch_W = batchInputShape[3]
  } else {
    console.log('Model Channel First')
    if (isNaN(batchInputShape[1]) || batchInputShape[1] !== 1) {
      const errTxt = 'The number of channels for the input shape must be 1'
      callbackUI(errTxt, -1, errTxt)
      return 0
    }
    batch_D = batchInputShape[2]
    batch_H = batchInputShape[3]
    batch_W = batchInputShape[4]
  }
  // const input_shape = [batchSize, numOfChan, batch_D, batch_H, batch_W]
  // -- Check whether the model makes inference all at once as a FullVolume model
  const isModelFullVol = batch_D === 256 && batch_H === 256 && batch_W === 256
  statData.isModelFullVol = isModelFullVol
  // Model output number of segmentations
  let slices_3d = await getAllSlicesDataAsTF3D(num_of_slices, niftiHeader, niftiImage)
  const transpose = modelEntry.enableTranspose
  const enableCrop = modelEntry.enableCrop
  if (isModelFullVol) {
    if (enableCrop) {
      // FullVolume with crop option before inference ..
      // The pre-model masks the volume; it can also be null, in which case cropping works directly on the MRI.
      await inferenceFullVolumePhase1(
        model,
        slices_3d,
        num_of_slices,
        slice_height,
        slice_width,
        isModelFullVol,
        modelEntry,
        statData,
        opts,
        niftiHeader,
        niftiImage
      )
    } else {
      // Transpose MRI data to match the PyTorch/Keras input/output
      console.log('Cropping Disabled')

      if (transpose) {
        slices_3d = slices_3d.transpose()
        console.log('Input transposed')
      } else {
        console.log('Transpose NOT Enabled')
      }

      const enableSeqConv = modelEntry.enableSeqConv

      if (enableSeqConv) {
        callbackUI('', -1, 'inferenceFullVolumeSeqCovLayer() is not dead code?')
      } else {
        callbackUI('', -1, 'inferenceFullVolume() is not dead code?')
      }
    }
  }
}

self.addEventListener(
  'message',
  function (event) {
    runInferenceWW(event.data.opts, event.data.modelEntry, event.data.niftiHeader, event.data.niftiImage)
  },
  false
)
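
// Usage sketch (main thread, illustrative): the worker expects a single message
// carrying { opts, modelEntry, niftiHeader, niftiImage }, matching the handler above.
//
//   const worker = new Worker('./brainchop-webworker.js', { type: 'module' })
//   worker.postMessage({ opts, modelEntry, niftiHeader, niftiImage })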