|
a |
|
b/v3/py2tfjs/conversion_example/meshnet.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from torch.utils.checkpoint import checkpoint_sequential |
|
|
4 |
import json |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def set_channel_num(config, in_channels, n_classes, channels):
    """Rewrite the channel counts of a MeshNet layer configuration in place.

    The first layer maps ``in_channels`` to ``channels``, the last layer maps
    ``channels`` to ``n_classes``, and every layer in between maps
    ``channels`` to ``channels``.

    Args:
        config (dict): Network configuration; ``config["layers"]`` is a
            non-empty list of per-layer dicts.
        in_channels (int): Number of input channels of the first layer.
        n_classes (int): Number of output channels of the last layer.
        channels (int): Channel width used everywhere else.

    Returns:
        dict: The same ``config`` object that was passed in, after mutation.
    """
    layer_specs = config["layers"]

    # First layer is written before the last one on purpose: when the list
    # holds a single layer, the output-layer values are the ones that stick.
    first_spec = layer_specs[0]
    first_spec["in_channels"] = in_channels
    first_spec["out_channels"] = channels

    last_spec = layer_specs[-1]
    last_spec["in_channels"] = channels
    last_spec["out_channels"] = n_classes

    # Everything strictly between the two ends keeps a constant width.
    for hidden_spec in layer_specs[1:-1]:
        hidden_spec.update(in_channels=channels, out_channels=channels)

    return config
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def construct_layer(dropout_p=0, bnorm=True, gelu=False, *args, **kwargs):
    """Construct a Conv3d block with optional batch norm, activation and dropout.

    Args:
        dropout_p (float): Dropout probability. A ``Dropout3d`` layer is
            appended only when this is greater than 0. Default is 0.
        bnorm (bool): Whether to add ``BatchNorm3d`` after the convolution.
            Default is True.
        gelu (bool): Selects the activation. NOTE: despite the name, a true
            value selects ``nn.ELU`` (not ``nn.GELU``); a false value selects
            ``nn.ReLU``. Default is False.
        *args: Additional positional arguments passed to ``nn.Conv3d``.
        **kwargs: Additional keyword arguments passed to ``nn.Conv3d``.

    Returns:
        nn.Sequential: ``Conv3d -> [BatchNorm3d] -> activation -> [Dropout3d]``,
        where the bracketed layers are present only when enabled.
    """
    conv = nn.Conv3d(*args, **kwargs)
    layers = [conv]
    if bnorm:
        # Size the norm from the constructed conv instead of
        # kwargs["out_channels"], so the block also works when the channel
        # counts are supplied positionally through *args.
        # NOTE(review): the previous comment here said
        # track_running_stats=False is needed for forward-mode AD, but the
        # code passes True; the value is left unchanged.
        layers.append(
            nn.BatchNorm3d(conv.out_channels, track_running_stats=True)
        )
    layers.append(nn.ELU(inplace=True) if gelu else nn.ReLU(inplace=True))
    if dropout_p > 0:
        layers.append(nn.Dropout3d(dropout_p))
    return nn.Sequential(*layers)
|
|
59 |
|
|
|
60 |
|
|
|
61 |
def init_weights(model):
    """Re-initialise every 2D/3D (transposed) convolution inside ``model``.

    Weights are drawn with Kaiming-normal initialisation (``fan_out`` mode,
    ReLU nonlinearity) and biases, where present, are set to zero. Modules of
    any other type are left untouched.

    Args:
        model (nn.Module): Module whose sub-modules are visited via
            ``model.modules()``.
    """
    conv_types = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    for module in model.modules():
        if not isinstance(module, conv_types):
            continue
        nn.init.kaiming_normal_(
            module.weight, mode="fan_out", nonlinearity="relu"
        )
        # Convolutions created with bias=False expose bias as None.
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)
|
|
73 |
|
|
|
74 |
|
|
|
75 |
class MeshNet(nn.Module):
    """Configurable MeshNet (https://arxiv.org/pdf/1612.00940.pdf).

    The layer stack is described by a JSON file and assembled into a single
    ``nn.Sequential`` stored as ``self.model``.
    """

    def __init__(self, in_channels, n_classes, channels, config_file, fat=None):
        """Build the network from a JSON configuration file.

        Args:
            in_channels (int): Number of input channels.
            n_classes (int): Number of output channels of the last layer.
            channels (int): Channel width of the remaining layers.
            config_file (str): Path to the JSON configuration. It is read for
                the keys ``"layers"``, ``"dropout_p"``, ``"bnorm"`` and
                ``"gelu"``.
            fat (str | None): Optional widening mode; the widened width is
                ``int(channels * 1.5)``. ``"i"`` widens the connection
                between layers 0 and 1, ``"io"`` additionally widens the
                connection between the last two layers, and ``"b"`` widens
                the connection between layers 3 and 4. Any other value
                leaves the widths unchanged.
        """
        with open(config_file, "r") as handle:
            raw_config = json.load(handle)
        config = set_channel_num(raw_config, in_channels, n_classes, channels)

        if fat is not None:
            wide = int(channels * 1.5)
            specs = config["layers"]
            if fat in {"i", "io"}:
                specs[0]["out_channels"] = wide
                specs[1]["in_channels"] = wide
            if fat == "io":
                specs[-1]["in_channels"] = wide
                specs[-2]["out_channels"] = wide
            if fat == "b":
                # Fixed positions: the config needs at least five layers.
                specs[3]["out_channels"] = wide
                specs[4]["in_channels"] = wide

        super().__init__()

        # Options shared by every block; per-layer settings come from the
        # individual entries of config["layers"].
        shared = {
            "dropout_p": config["dropout_p"],
            "bnorm": config["bnorm"],
            "gelu": config["gelu"],
        }
        blocks = []
        for spec in config["layers"]:
            blocks.append(construct_layer(**shared, **spec))

        # Keep only the first sub-module of the final block (the convolution
        # that construct_layer puts first), dropping its norm, activation
        # and dropout.
        blocks[-1] = blocks[-1][0]
        self.model = nn.Sequential(*blocks)
        init_weights(self.model)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.model(x)
|
|
116 |
|
|
|
117 |
|
|
|
118 |
class enMesh_checkpoint(MeshNet):
    """MeshNet variant with a checkpointed training pass.

    ``forward`` dispatches on ``self.training``: the training path goes
    through ``checkpoint_sequential``, the evaluation path runs the plain
    stack under ``torch.inference_mode()``.
    """

    def train_forward(self, x):
        """Run the stack through ``checkpoint_sequential``.

        The input tensor is flagged with ``requires_grad_()`` in place, and
        the number of segments equals the number of modules in
        ``self.model``.
        """
        # NOTE(review): presumably required so gradients flow through the
        # checkpointed segments — confirm against the torch version in use.
        x.requires_grad_()
        segment_count = len(self.model)
        # NOTE(review): preserve_rng_state=False means RNG state is not
        # restored when segments are recomputed; verify this is intended
        # when dropout is enabled.
        return checkpoint_sequential(
            self.model, segment_count, x, preserve_rng_state=False
        )

    def eval_forward(self, x):
        """Switch ``self.model`` to eval mode and run it in inference mode."""
        self.model.eval()
        with torch.inference_mode():
            prediction = self.model(x)
        return prediction

    def forward(self, x):
        """Use the checkpointed path while training, the eval path otherwise."""
        run = self.train_forward if self.training else self.eval_forward
        return run(x)