# v3/py2tfjs/meshnet.py
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


# Per-layer Conv3d hyperparameters for the smaller MeshNet configuration.
# The first block's "in_channels" and the last block's "out_channels" are -1
# placeholders, filled in by MeshNet.__init__ from n_channels and n_classes.
MeshNet_38_or_64_kwargs = [
    {
        "in_channels": -1,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 21,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 21,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 21,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 2,
        "stride": 1,
        "dilation": 2,
    },
    {
        "in_channels": 21,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 4,
        "stride": 1,
        "dilation": 4,
    },
    {
        "in_channels": 21,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 8,
        "stride": 1,
        "dilation": 8,
    },
    {
        "in_channels": 21,
        "kernel_size": 3,
        "out_channels": 21,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 21,
        "kernel_size": 1,
        "out_channels": -1,
        "padding": 0,
        "stride": 1,
        "dilation": 1,
    },
]
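

# Illustrative helper (not used by the model code): with stride 1 and padding equal
# to dilation, a kernel_size-3 Conv3d preserves the spatial extent, because
# out = (in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1 == in.
# The function below only documents that property of the configurations in this file.
def _preserves_spatial_size(block_kwargs_list, size=64):
    """Return True if every block keeps a `size`-voxel spatial dimension unchanged."""
    for kw in block_kwargs_list:
        out = (
            size + 2 * kw["padding"] - kw["dilation"] * (kw["kernel_size"] - 1) - 1
        ) // kw["stride"] + 1
        if out != size:
            return False
    return True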


# Per-layer Conv3d hyperparameters for the large MeshNet configuration
# (used when large=True); same -1 placeholder convention as above.
MeshNet_68_kwargs = [
    {
        "in_channels": -1,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 71,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 71,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 2,
        "stride": 1,
        "dilation": 2,
    },
    {
        "in_channels": 71,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 4,
        "stride": 1,
        "dilation": 4,
    },
    {
        "in_channels": 71,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 8,
        "stride": 1,
        "dilation": 8,
    },
    {
        "in_channels": 71,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 16,
        "stride": 1,
        "dilation": 16,
    },
    {
        "in_channels": 71,
        "kernel_size": 3,
        "out_channels": 71,
        "padding": 1,
        "stride": 1,
        "dilation": 1,
    },
    {
        "in_channels": 71,
        "kernel_size": 1,
        "out_channels": -1,
        "padding": 0,
        "stride": 1,
        "dilation": 1,
    },
]
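

# Illustrative helper (not used by the model code): consecutive blocks must chain,
# i.e. each block's "in_channels" equals the previous block's "out_channels".
# Handy as a sanity check when editing the configuration lists above.
def _channels_chain(block_kwargs_list):
    """Return True if out_channels of block i matches in_channels of block i + 1."""
    return all(
        prev["out_channels"] == nxt["in_channels"]
        for prev, nxt in zip(block_kwargs_list, block_kwargs_list[1:])
    )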


def conv_w_bn_before_act(dropout_p=0, *args, **kwargs):
    """Configurable Conv3d -> BatchNorm3d -> ReLU -> Dropout3d block.

    The Conv3d arguments must be given as keywords (at least "out_channels",
    which is also used to size the BatchNorm3d layer).
    """
    return nn.Sequential(
        nn.Conv3d(*args, **kwargs),
        nn.BatchNorm3d(kwargs["out_channels"]),
        nn.ReLU(inplace=True),
        nn.Dropout3d(dropout_p),
    )
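

# Minimal usage sketch (illustrative only): build a single block from the second
# entry of the small configuration, which already has concrete channel counts
# (21 -> 21). The dropout value is arbitrary.
def _example_block(dropout_p=0.1):
    """Return one Conv3d/BatchNorm3d/ReLU/Dropout3d block for illustration."""
    return conv_w_bn_before_act(dropout_p=dropout_p, **MeshNet_38_or_64_kwargs[1])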


def init_weights(model):
    """Initialize all conv weights with Xavier normal and zero their biases."""
    for m in model.modules():
        if isinstance(
            m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        ):
            nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain("relu"))
            # Guard against conv layers created with bias=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
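

# Small illustration (not used by the model code): applying init_weights to a lone
# Conv3d re-initializes its weight with Xavier normal (ReLU gain) and zeroes its bias.
def _init_weights_demo():
    """Return the max |bias| of a freshly re-initialized Conv3d (expected 0.0)."""
    conv = nn.Conv3d(1, 4, kernel_size=3, padding=1)
    init_weights(conv)
    return float(conv.bias.detach().abs().max())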


class MeshNet(nn.Module):
    """Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf"""

    def __init__(self, n_channels, n_classes, large=True, dropout_p=0):
        """Build the model.

        n_channels: number of input channels (e.g. 1 for a single MRI volume).
        n_classes: number of output segmentation classes.
        large: use the wider MeshNet_68_kwargs configuration when True.
        dropout_p: dropout probability applied after each hidden block.
        """
        super().__init__()
        # Copy the configuration so the module-level templates are not mutated.
        params = copy.deepcopy(MeshNet_68_kwargs if large else MeshNet_38_or_64_kwargs)
        params[0]["in_channels"] = n_channels
        params[-1]["out_channels"] = n_classes
        layers = [
            conv_w_bn_before_act(dropout_p=dropout_p, **block_kwargs)
            for block_kwargs in params[:-1]
        ]
        # The final layer is a plain 1x1x1 Conv3d producing per-class logits.
        layers.append(nn.Conv3d(**params[-1]))
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, x):
        """Forward pass: return per-voxel class logits."""
        return self.model(x)
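

# Illustrative smoke test (example values are arbitrary): run a random 1-channel
# 38^3 volume through the small model and check that the spatial shape is preserved
# while the channel axis becomes the number of classes.
if __name__ == "__main__":
    net = MeshNet(n_channels=1, n_classes=3, large=False, dropout_p=0.1)
    net.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 1, 38, 38, 38)  # (batch, channels, D, H, W)
        logits = net(dummy)
    assert logits.shape == (1, 3, 38, 38, 38), logits.shape
    print("output shape:", tuple(logits.shape))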