v3/py2tfjs/conversion_example/meshnet.py

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
import json


def set_channel_num(config, in_channels, n_classes, channels):
    """Set the channel counts in a MeshNet configuration.

    Updates the configuration so that the input layer maps `in_channels`
    to `channels`, every hidden layer maps `channels` to `channels`, and
    the output layer maps `channels` to `n_classes`. The dict is modified
    in place and returned for convenience.

    Args:
        config (dict): The network configuration (parsed JSON).
        in_channels (int): The number of input channels.
        n_classes (int): The number of output classes.
        channels (int): The number of channels for every layer except the
            input and output layers.

    Returns:
        dict: The updated configuration.
    """
    # input layer
    config["layers"][0]["in_channels"] = in_channels
    config["layers"][0]["out_channels"] = channels

    # output layer
    config["layers"][-1]["in_channels"] = channels
    config["layers"][-1]["out_channels"] = n_classes

    # hidden layers
    for layer in config["layers"][1:-1]:
        layer["in_channels"] = layer["out_channels"] = channels

    return config
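

# A usage sketch for set_channel_num (illustrative; the three-layer config
# below is a hypothetical stand-in for the JSON configs this example loads):
#
#   cfg = {"layers": [
#       {"in_channels": -1, "out_channels": -1, "kernel_size": 3, "padding": 1},
#       {"in_channels": -1, "out_channels": -1, "kernel_size": 3, "padding": 2, "dilation": 2},
#       {"in_channels": -1, "out_channels": -1, "kernel_size": 1, "padding": 0},
#   ]}
#   cfg = set_channel_num(cfg, in_channels=1, n_classes=3, channels=16)
#   # first layer: 1 -> 16, hidden layer: 16 -> 16, last layer: 16 -> 3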


def construct_layer(dropout_p=0, bnorm=True, gelu=False, *args, **kwargs):
    """Construct a configurable convolutional block.

    The block is Conv3d -> optional BatchNorm3d -> activation ->
    optional Dropout3d.

    Args:
        dropout_p (float): Dropout probability. Default is 0 (no dropout).
        bnorm (bool): Whether to include batch normalization. Default is True.
        gelu (bool): If True, use ELU instead of ReLU as the activation.
            Note that despite the flag name, the block uses nn.ELU,
            not nn.GELU. Default is False.
        *args: Additional positional arguments to pass to nn.Conv3d.
        **kwargs: Additional keyword arguments to pass to nn.Conv3d;
            must include out_channels when bnorm is True.

    Returns:
        nn.Sequential: The convolutional block.
    """
    layers = []
    layers.append(nn.Conv3d(*args, **kwargs))
    if bnorm:
        # track_running_stats=False would be required to run forward-mode AD;
        # True is the standard behavior for training and inference
        layers.append(
            nn.BatchNorm3d(kwargs["out_channels"], track_running_stats=True)
        )
    # despite the flag name, ELU (not GELU) is the alternative to ReLU
    layers.append(nn.ELU(inplace=True) if gelu else nn.ReLU(inplace=True))
    if dropout_p > 0:
        layers.append(nn.Dropout3d(dropout_p))
    return nn.Sequential(*layers)
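

# A sketch of what a single block looks like (the Conv3d kwargs here are
# hypothetical, but match the shape of the per-layer config entries above):
#
#   block = construct_layer(
#       dropout_p=0.1, bnorm=True, gelu=False,
#       in_channels=16, out_channels=16, kernel_size=3, padding=1,
#   )
#   # -> Sequential(Conv3d(16, 16, kernel_size=3, padding=1),
#   #               BatchNorm3d(16), ReLU(inplace=True), Dropout3d(p=0.1))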


def init_weights(model):
    """Initialize all conv weights with Kaiming normal and biases with zeros."""
    for m in model.modules():
        if isinstance(
            m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        ):
            # alternative: nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain("relu"))
            nn.init.kaiming_normal_(
                m.weight, mode="fan_out", nonlinearity="relu"
            )
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
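

# init_weights is applied to the assembled model in MeshNet.__init__ below,
# but it works on any module tree, e.g. (sketch):
#
#   net = nn.Sequential(nn.Conv3d(1, 8, 3), nn.ReLU())
#   init_weights(net)  # Kaiming-normal conv weights, zero biases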


class MeshNet(nn.Module):
    """Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf"""

    def __init__(self, in_channels, n_classes, channels, config_file, fat=None):
        """Build the network from a JSON configuration file.

        Args:
            in_channels (int): Number of input channels.
            n_classes (int): Number of output classes.
            channels (int): Number of channels for every hidden layer.
            config_file (str): Path to the JSON configuration.
            fat (str, optional): Widen selected layers to int(1.5 * channels):
                "i" widens the input transition, "io" the input and output
                transitions, "b" the middle (bottleneck) layers 3 and 4.
        """
        super(MeshNet, self).__init__()
        with open(config_file, "r") as f:
            config = set_channel_num(
                json.load(f), in_channels, n_classes, channels
            )

        if fat is not None:
            chn = int(channels * 1.5)
            if fat in {"i", "io"}:
                config["layers"][0]["out_channels"] = chn
                config["layers"][1]["in_channels"] = chn
            if fat == "io":
                config["layers"][-1]["in_channels"] = chn
                config["layers"][-2]["out_channels"] = chn
            if fat == "b":
                config["layers"][3]["out_channels"] = chn
                config["layers"][4]["in_channels"] = chn

        layers = [
            construct_layer(
                dropout_p=config["dropout_p"],
                bnorm=config["bnorm"],
                gelu=config["gelu"],
                **block_kwargs,
            )
            for block_kwargs in config["layers"]
        ]
        # keep only the Conv3d of the last block: the output layer gets no
        # normalization, activation, or dropout
        layers[-1] = layers[-1][0]
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, x):
        """Forward pass"""
        x = self.model(x)
        return x
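

# Constructing the plain model (sketch; "meshnet_config.json" is a
# hypothetical path, any config with the keys used above will do):
#
#   model = MeshNet(in_channels=1, n_classes=3, channels=16,
#                   config_file="meshnet_config.json")
#   out = model(torch.zeros(1, 1, 38, 38, 38))
#   # logits shaped (1, 3, 38, 38, 38) when every conv preserves spatial size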


class enMesh_checkpoint(MeshNet):
    """MeshNet that trades compute for memory via gradient checkpointing."""

    def train_forward(self, x):
        """Forward pass with activation checkpointing (training only)."""
        y = x
        # the input must require grad for checkpointing to build the graph;
        # note that requires_grad_ mutates the input tensor itself
        y.requires_grad_()
        # checkpoint every block: activations are recomputed during the
        # backward pass instead of being stored
        y = checkpoint_sequential(
            self.model, len(self.model), y, preserve_rng_state=False
        )
        return y

    def eval_forward(self, x):
        """Forward pass without autograd bookkeeping (inference only)."""
        self.model.eval()
        with torch.inference_mode():
            x = self.model(x)
        return x

    def forward(self, x):
        if self.training:
            return self.train_forward(x)
        else:
            return self.eval_forward(x)
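

# A minimal smoke test (a sketch: the config dict below is a hypothetical
# stand-in for the real JSON configs shipped with this example).
if __name__ == "__main__":
    import tempfile

    cfg = {
        "dropout_p": 0.0,
        "bnorm": True,
        "gelu": False,
        "layers": [
            {"in_channels": -1, "out_channels": -1, "kernel_size": 3, "padding": 1},
            {"in_channels": -1, "out_channels": -1, "kernel_size": 3, "padding": 1},
            {"in_channels": -1, "out_channels": -1, "kernel_size": 1, "padding": 0},
        ],
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(cfg, f)
        path = f.name

    model = enMesh_checkpoint(
        in_channels=1, n_classes=2, channels=8, config_file=path
    )
    model.eval()
    out = model(torch.zeros(2, 1, 16, 16, 16))
    print(out.shape)  # expected: torch.Size([2, 2, 16, 16, 16])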