Diff of /v3/py2tfjs/pytorch2js.py [000000] .. [b86468]

Switch to unified view

a b/v3/py2tfjs/pytorch2js.py
1
from blendbatchnorm import fuse_bn_recursively
2
import onnx
3
from onnx2keras import onnx_to_keras
4
import torch
5
import numpy as np
6
from meshnet import MeshNet
7
import tensorflowjs as tfjs
8
from fixmodeljson import fixjson_file
9
10
def preprocess_image(img):
11
    """Unit interval preprocessing"""
12
    img = (img - img.min()) / (img.max() - img.min())
13
    return img
14
15
volume_shape = [256, 256, 256]
16
subvolume_shape = [38, 38, 38]
17
n_subvolumes = 1024
18
n_classes = 3
19
atlas_classes = 104
20
scube = 64
21
22
model_path = '../meshnet_gmwm_dropout_train.30_full.pth'
23
#'meshnet_gmwm_dropout_train.30_full.pth'#'meshnet_gmwm_train.30_full.pth'
24
25
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
26
device = torch.device(device_name)
27
meshnet_model = MeshNet(n_channels=1, n_classes=n_classes, large=False)
28
29
meshnet_model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
30
31
meshnet_model.to(device)
32
mnm = fuse_bn_recursively(meshnet_model)
33
mnm.model.eval();
34
35
x = torch.randn(1, 1, scube, scube, scube, requires_grad=True)
36
torch.onnx.export(mnm, x.to(device), '/tmp/mnm_model_large.onnx', export_params=True, opset_version=13, do_constant_folding=True, input_names = ['input'], output_names = ['output'],dynamic_axes={'input' : {0 : 'batch_size'},'output' : {0 : 'batch_size'}})
37
onnx_model = onnx.load('/tmp/mnm_model_large.onnx')
38
k_model = onnx_to_keras(onnx_model, ['input'])
39
40
tfjs.converters.save_keras_model(k_model, '/tmp/mnm_gmwm_dropout256')
41
fixjson_file('/tmp/mnm_gmwm_dropout256/model.json', scube=scube)