|
a |
|
b/v3/py2tfjs/conversion_example/convert.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
from blendbatchnorm import fuse_bn_recursively |
|
|
4 |
from meshnet2tfjs import meshnet2tfjs |
|
|
5 |
|
|
|
6 |
from meshnet import ( |
|
|
7 |
MeshNet, |
|
|
8 |
enMesh_checkpoint, |
|
|
9 |
) |
|
|
10 |
|
|
|
11 |
device_name = "cuda:0" if torch.cuda.is_available() else "cpu" |
|
|
12 |
device = torch.device(device_name) |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def preprocess_image(img, qmin=0.01, qmax=0.99): |
|
|
16 |
"""Unit interval preprocessing""" |
|
|
17 |
img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin)) |
|
|
18 |
return img |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
# def preprocess_image(img): |
|
|
22 |
# """Unit interval preprocessing""" |
|
|
23 |
# img = (img - img.min()) / (img.max() - img.min()) |
|
|
24 |
# return img |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
# specify how many classes does the model predict |
|
|
28 |
n_classes = 3 |
|
|
29 |
# specify the architecture |
|
|
30 |
config_file = "modelAE.json" |
|
|
31 |
# how many channels does the saved model have |
|
|
32 |
model_channels = 15 |
|
|
33 |
# path to the saved model |
|
|
34 |
model_path = "model.pth" |
|
|
35 |
# tfjs model output directory |
|
|
36 |
tfjs_model_dir = "model_tfjs" |
|
|
37 |
|
|
|
38 |
meshnet_model = enMesh_checkpoint( |
|
|
39 |
in_channels=1, |
|
|
40 |
n_classes=n_classes, |
|
|
41 |
channels=model_channels, |
|
|
42 |
config_file=config_file, |
|
|
43 |
) |
|
|
44 |
|
|
|
45 |
checkpoint = torch.load(model_path) |
|
|
46 |
meshnet_model.load_state_dict(checkpoint) |
|
|
47 |
|
|
|
48 |
meshnet_model.eval() |
|
|
49 |
|
|
|
50 |
meshnet_model.to(device) |
|
|
51 |
mnm = fuse_bn_recursively(meshnet_model) |
|
|
52 |
del meshnet_model |
|
|
53 |
mnm.model.eval() |
|
|
54 |
|
|
|
55 |
|
|
|
56 |
meshnet2tfjs(mnm, tfjs_model_dir) |