|
a |
|
b/v3/py2tfjs/pytorch2js.py |
|
|
1 |
from blendbatchnorm import fuse_bn_recursively |
|
|
2 |
import onnx |
|
|
3 |
from onnx2keras import onnx_to_keras |
|
|
4 |
import torch |
|
|
5 |
import numpy as np |
|
|
6 |
from meshnet import MeshNet |
|
|
7 |
import tensorflowjs as tfjs |
|
|
8 |
from fixmodeljson import fixjson_file |
|
|
9 |
|
|
|
10 |
def preprocess_image(img):
    """Rescale ``img`` linearly onto the unit interval [0, 1].

    The minimum value maps to 0 and the maximum to 1. A constant input has a
    zero range, which would divide by zero and yield NaNs. It is mapped to all
    zeros instead.

    Args:
        img: array-like supporting ``.min()``, ``.max()`` and elementwise
            arithmetic (e.g. a NumPy array or a torch tensor).

    Returns:
        The rescaled values, same shape as ``img``.
    """
    lo = img.min()
    span = img.max() - lo
    shifted = img - lo
    if span == 0:
        # Multiply by 0.0 rather than returning ``shifted`` directly so that
        # integer inputs still come back as floats, like the normal path.
        return shifted * 0.0
    return shifted / span
|
|
14 |
|
|
|
15 |
# Configuration constants. Only ``n_classes`` and ``scube`` are read by the
# conversion below; the others are kept for reference.
# NOTE(review): confirm nothing else imports the unused ones before removing.
volume_shape = [256, 256, 256]
subvolume_shape = [38, 38, 38]
n_subvolumes = 1024
n_classes = 3
atlas_classes = 104
scube = 64  # edge length of the cubic dummy input used to trace the model

# Checkpoint to convert. Other checkpoint names previously referenced here:
# 'meshnet_gmwm_dropout_train.30_full.pth', 'meshnet_gmwm_train.30_full.pth'.
model_path = '../meshnet_gmwm_dropout_train.30_full.pth'

# Intermediate ONNX file and the output directory for the TF.js model.
# Defined once so the export/load and save/fix steps cannot drift apart.
onnx_path = '/tmp/mnm_model_large.onnx'
tfjs_dir = '/tmp/mnm_gmwm_dropout256'

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Build the network and load trained weights from the checkpoint's
# 'model_state_dict' entry.
meshnet_model = MeshNet(n_channels=1, n_classes=n_classes, large=False)
meshnet_model.load_state_dict(
    torch.load(model_path, map_location=device)['model_state_dict']
)
meshnet_model.to(device)

# Fuse batch-norm layers (helper from blendbatchnorm). Then put the whole
# fused module in inference mode. ``eval()`` is recursive, so it covers the
# ``.model`` submodule as well as any other child module.
mnm = fuse_bn_recursively(meshnet_model)
mnm.eval()

# Random dummy input used only to trace the graph: shape
# (1, 1, scube, scube, scube). Export needs no gradients, so requires_grad
# is left off.
x = torch.randn(1, 1, scube, scube, scube)
torch.onnx.export(
    mnm,
    x.to(device),
    onnx_path,
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    # Only the batch dimension is declared dynamic; the other dimensions are
    # those of the traced dummy input.
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)

# ONNX -> Keras -> TF.js.
onnx_model = onnx.load(onnx_path)
k_model = onnx_to_keras(onnx_model, ['input'])

tfjs.converters.save_keras_model(k_model, tfjs_dir)
# Post-process the generated model.json (helper from fixmodeljson), passing
# the cube size used for tracing.
fixjson_file(f'{tfjs_dir}/model.json', scube=scube)