--- a +++ b/v3/py2tfjs/meshnet.py @@ -0,0 +1,183 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +MeshNet_38_or_64_kwargs = [ + { + "in_channels": -1, + "kernel_size": 3, + "out_channels": 21, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 21, + "kernel_size": 3, + "out_channels": 21, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 21, + "kernel_size": 3, + "out_channels": 21, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 21, + "kernel_size": 3, + "out_channels": 21, + "padding": 2, + "stride": 1, + "dilation": 2, + }, + { + "in_channels": 21, + "kernel_size": 3, + "out_channels": 21, + "padding": 4, + "stride": 1, + "dilation": 4, + }, + { + "in_channels": 21, + "kernel_size": 3, + "out_channels": 21, + "padding": 8, + "stride": 1, + "dilation": 8, + }, + { + "in_channels": 21, + "kernel_size": 3, + "out_channels": 21, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 21, + "kernel_size": 1, + "out_channels": -1, + "padding": 0, + "stride": 1, + "dilation": 1, + }, +] + +MeshNet_68_kwargs = [ + { + "in_channels": -1, + "kernel_size": 3, + "out_channels": 71, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 71, + "kernel_size": 3, + "out_channels": 71, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 71, + "kernel_size": 3, + "out_channels": 71, + "padding": 2, + "stride": 1, + "dilation": 2, + }, + { + "in_channels": 71, + "kernel_size": 3, + "out_channels": 71, + "padding": 4, + "stride": 1, + "dilation": 4, + }, + { + "in_channels": 71, + "kernel_size": 3, + "out_channels": 71, + "padding": 8, + "stride": 1, + "dilation": 8, + }, + { + "in_channels": 71, + "kernel_size": 3, + "out_channels": 71, + "padding": 16, + "stride": 1, + "dilation": 16, + }, + { + "in_channels": 71, + "kernel_size": 3, + "out_channels": 71, + "padding": 1, + "stride": 1, + "dilation": 1, + }, + { + "in_channels": 71, + "kernel_size": 1, + "out_channels": -1, + "padding": 0, + "stride": 1, + "dilation": 1, + }, +] + + +def conv_w_bn_before_act(dropout_p=0, *args, **kwargs): + """Configurable Conv block with Batchnorm and Dropout""" + return nn.Sequential( + nn.Conv3d(*args, **kwargs), + nn.BatchNorm3d(kwargs["out_channels"]), + nn.ReLU(inplace=True), + nn.Dropout3d(dropout_p), + ) + + +def init_weights(model): + """Set weights to be xavier normal for all Convs""" + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain("relu")) + nn.init.constant_(m.bias, 0.0) + + +class MeshNet(nn.Module): + """Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf""" + + def __init__(self, n_channels, n_classes, large=True, dropout_p=0): + """Init""" + if large: + params = MeshNet_68_kwargs + else: + params = MeshNet_38_or_64_kwargs + + super(MeshNet, self).__init__() + params[0]["in_channels"] = n_channels + params[-1]["out_channels"] = n_classes + layers = [ + conv_w_bn_before_act(dropout_p=dropout_p, **block_kwargs) + for block_kwargs in params[:-1] + ] + layers.append(nn.Conv3d(**params[-1])) + self.model = nn.Sequential(*layers) + init_weights(self.model,) + + def forward(self, x): + """Forward pass""" + x = self.model(x) + return x