[b86468]: / py2tfjs / conversion_example / convert.py

Download this file

57 lines (41 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from blendbatchnorm import fuse_bn_recursively
from meshnet2tfjs import meshnet2tfjs
from meshnet import (
MeshNet,
enMesh_checkpoint,
)
# Select the compute device: the first CUDA GPU when one is usable, else CPU.
if torch.cuda.is_available():
    device_name = "cuda:0"
else:
    device_name = "cpu"
device = torch.device(device_name)
def preprocess_image(img, qmin=0.01, qmax=0.99):
    """Unit interval preprocessing based on quantiles.

    Linearly rescales ``img`` so that its ``qmin`` quantile maps to 0 and its
    ``qmax`` quantile maps to 1. Values are not clipped: entries below the
    lower quantile become negative and entries above the upper quantile
    exceed 1.

    Args:
        img: Tensor to normalize. ``Tensor.quantile`` requires a
            floating-point dtype.
        qmin: Lower quantile (in [0, 1]) mapped to 0.
        qmax: Upper quantile (in [0, 1]) mapped to 1.

    Returns:
        A new tensor ``(img - q_lo) / (q_hi - q_lo)`` with the shape of ``img``.
    """
    # Tensor.quantile is a reduction over the whole tensor. Compute each bound
    # once instead of evaluating the lower quantile twice.
    lo = img.quantile(qmin)
    hi = img.quantile(qmax)
    # NOTE(review): if hi == lo (e.g. a constant image), this divides by zero
    # and produces nan/inf, the same as before this change.
    img = (img - lo) / (hi - lo)
    return img
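# Alternative min/max normalization, kept for reference. It is more sensitive
# to outliers than the quantile-based version above:
# def preprocess_image(img):
#     """Unit interval preprocessing"""
#     img = (img - img.min()) / (img.max() - img.min())
#     return img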
# --- Conversion settings ------------------------------------------------------
# Number of classes the model predicts.
n_classes = 3
# Architecture description file.
config_file = "modelAE.json"
# Number of channels in the saved model.
model_channels = 15
# Path to the saved model (consumed below as a state dict).
model_path = "model.pth"
# Output directory for the tfjs model.
tfjs_model_dir = "model_tfjs"

meshnet_model = enMesh_checkpoint(
    in_channels=1,
    n_classes=n_classes,
    channels=model_channels,
    config_file=config_file,
)
# Deserialize onto the CPU, where the freshly built model still lives.
# Without map_location, a checkpoint saved from a GPU cannot be loaded on a
# CPU-only machine, which is exactly the case where `device` falls back to "cpu".
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (weights_only=True is available on newer PyTorch versions).
checkpoint = torch.load(model_path, map_location="cpu")
meshnet_model.load_state_dict(checkpoint)
meshnet_model.eval()
meshnet_model.to(device)
# Presumably folds BatchNorm layers into neighbouring layers for export
# (see blendbatchnorm.fuse_bn_recursively) - confirm.
mnm = fuse_bn_recursively(meshnet_model)
del meshnet_model
mnm.model.eval()
# Export the fused model; output is written under tfjs_model_dir.
meshnet2tfjs(mnm, tfjs_model_dir)