# py2tfjs/conversion_example/blendbatchnorm.py
# Adapted from:
# https://github.com/MIPT-Oulu/pytorch_bn_fusion/blob/master/bn_fusion.py
import torch
import torch.nn as nn
def fuse_bn_sequential(block):
    """Fuse Conv+BatchNorm pairs inside a sequential block.

    Each BatchNorm that immediately follows a convolution is folded into
    that convolution's weight and bias using the BN's *inference-mode*
    statistics (running mean/var), so the fused model is only equivalent
    to the original in ``eval()`` mode.

    :param block: candidate module; anything that is not an
        ``nn.Sequential`` is returned unchanged.
    :return: the converted block — an ``nn.Sequential`` with the BN layers
        folded away, or the single remaining module if only one child
        survives fusion.
    """
    if not isinstance(block, nn.Sequential):
        return block

    # The fusion algebra is identical for 1d/2d/3d conv+BN pairs.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)

    stack = []
    for m in block.children():
        # Guard `stack` being empty: a BN as the first child previously
        # raised IndexError on stack[-1].
        if isinstance(m, bn_types) and stack and isinstance(stack[-1], conv_types):
            conv = stack[-1]
            bn_st_dict = m.state_dict()
            conv_st_dict = conv.state_dict()

            # BatchNorm params: y = gamma * (x - mu) / sqrt(var + eps) + beta
            eps = m.eps
            mu = bn_st_dict["running_mean"]
            var = bn_st_dict["running_var"]
            gamma = bn_st_dict["weight"]
            beta = bn_st_dict.get("bias", torch.zeros_like(gamma))

            # Conv params (bias defaults to zeros when the conv has none).
            W = conv_st_dict["weight"]
            bias = conv_st_dict.get(
                "bias",
                torch.zeros(W.size(0), dtype=gamma.dtype, device=gamma.device),
            )

            # Rewrite BN as the affine map y = A*x + b and fold it in:
            #   W' = A * W (per output channel),  b' = A * bias + b
            denom = torch.sqrt(var + eps)
            A = gamma.div(denom)
            b = beta - gamma.mul(mu).div(denom)
            fused_bias = bias.mul(A).add(b)
            # Broadcast the per-channel scale over every weight dim but dim 0.
            A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

            conv.weight.data.copy_(W.mul(A))
            if conv.bias is None:
                conv.bias = torch.nn.Parameter(fused_bias)
            else:
                conv.bias.data.copy_(fused_bias)
            # The BN layer is consumed; do not append it.
        else:
            # BUG FIX: a BN not preceded by a conv used to be dropped
            # silently, changing the model's output. Keep it instead.
            stack.append(m)

    # Unwrap single survivors; an empty block yields an empty Sequential
    # rather than the old IndexError on stack[0].
    if len(stack) == 1:
        return stack[0]
    return nn.Sequential(*stack)
61
62
63
def fuse_bn_recursively(model):
    """Recursively fold BatchNorm layers into preceding convolutions.

    Walks the module tree top-down, replacing every direct child with its
    BN-fused equivalent and then descending into children that still have
    submodules. Mutates *model* in place and returns it for convenience.
    """
    for name in list(model._modules):
        fused = fuse_bn_sequential(model._modules[name])
        model._modules[name] = fused
        if fused._modules:
            fuse_bn_recursively(fused)
    return model