[b86468]: / v3 / py2tfjs / pytorch2js.py

Download this file

42 lines (33 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from blendbatchnorm import fuse_bn_recursively
import onnx
from onnx2keras import onnx_to_keras
import torch
import numpy as np
from meshnet import MeshNet
import tensorflowjs as tfjs
from fixmodeljson import fixjson_file
def preprocess_image(img):
"""Unit interval preprocessing"""
img = (img - img.min()) / (img.max() - img.min())
return img
volume_shape = [256, 256, 256]
subvolume_shape = [38, 38, 38]
n_subvolumes = 1024
n_classes = 3
atlas_classes = 104
scube = 64
model_path = '../meshnet_gmwm_dropout_train.30_full.pth'
#'meshnet_gmwm_dropout_train.30_full.pth'#'meshnet_gmwm_train.30_full.pth'
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
meshnet_model = MeshNet(n_channels=1, n_classes=n_classes, large=False)
meshnet_model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
meshnet_model.to(device)
mnm = fuse_bn_recursively(meshnet_model)
mnm.model.eval();
x = torch.randn(1, 1, scube, scube, scube, requires_grad=True)
torch.onnx.export(mnm, x.to(device), '/tmp/mnm_model_large.onnx', export_params=True, opset_version=13, do_constant_folding=True, input_names = ['input'], output_names = ['output'],dynamic_axes={'input' : {0 : 'batch_size'},'output' : {0 : 'batch_size'}})
onnx_model = onnx.load('/tmp/mnm_model_large.onnx')
k_model = onnx_to_keras(onnx_model, ['input'])
tfjs.converters.save_keras_model(k_model, '/tmp/mnm_gmwm_dropout256')
fixjson_file('/tmp/mnm_gmwm_dropout256/model.json', scube=scube)