# v3/py2tfjs/conversion_example/convert.py
import torch

from blendbatchnorm import fuse_bn_recursively
from meshnet2tfjs import meshnet2tfjs

from meshnet import (
    MeshNet,
    enMesh_checkpoint,
)

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)


def preprocess_image(img, qmin=0.01, qmax=0.99):
    """Unit interval preprocessing"""
    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
    return img


# Alternative: plain min-max normalization, kept for reference
# def preprocess_image(img):
#     """Unit interval preprocessing"""
#     img = (img - img.min()) / (img.max() - img.min())
#     return img
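
# Illustration (not part of the original script): unlike min-max scaling, the
# quantile-based version above is robust to a few extreme voxels, since the
# 1st and 99th percentiles, not the min and max, define the unit range.
# A quick check on a toy tensor (names here are illustrative only):
#
# x = torch.rand(16, 16, 16)
# y = preprocess_image(x)
# print(y.quantile(0.01).item(), y.quantile(0.99).item())  # ~0.0 and ~1.0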

# how many classes the model predicts
n_classes = 3
# JSON config that specifies the architecture
config_file = "modelAE.json"
# how many channels the saved model has
model_channels = 15
# path to the saved model weights
model_path = "model.pth"
# tfjs model output directory
tfjs_model_dir = "model_tfjs"

# build the model from the architecture config
meshnet_model = enMesh_checkpoint(
    in_channels=1,
    n_classes=n_classes,
    channels=model_channels,
    config_file=config_file,
)

# load the trained weights; map_location keeps this working on CPU-only machines
checkpoint = torch.load(model_path, map_location=device)
meshnet_model.load_state_dict(checkpoint)

meshnet_model.eval()

meshnet_model.to(device)
# fold BatchNorm parameters into the neighboring convolutions before export
mnm = fuse_bn_recursively(meshnet_model)
del meshnet_model
mnm.model.eval()


# export the fused model to the TensorFlow.js format
meshnet2tfjs(mnm, tfjs_model_dir)
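

# Optional sanity check (a sketch, not part of the original conversion script):
# run the fused PyTorch model on a preprocessed volume before trusting the
# exported tfjs model. The nibabel dependency, the "t1.nii.gz" filename, and
# the 256^3 input shape are assumptions; adjust them to your data.
#
# import nibabel as nib
#
# volume = torch.from_numpy(nib.load("t1.nii.gz").get_fdata()).float()
# volume = preprocess_image(volume).reshape(1, 1, 256, 256, 256).to(device)
# with torch.inference_mode():
#     segmentation = mnm(volume).argmax(dim=1)
# print(segmentation.shape, segmentation.unique())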