import torch
import torch.nn as nn
def fuse_bn_sequential(block):
    """
    Fuse each Conv3d + BatchNorm3d pair inside a sequential block into a
    single Conv3d whose weight and bias absorb the (eval-mode) normalization:

        y = gamma * (conv(x) - mu) / sqrt(var + eps) + beta
          = conv'(x)   with   W' = W * gamma/denom,
                              b' = (b - mu) * gamma/denom + beta

    :param block: nn.Sequential. Source block to convert (returned unchanged
        if it is not an nn.Sequential).
    :return: nn.Module. The converted block — an nn.Sequential, or the single
        remaining child when only one module is left after fusion.
    """
    if not isinstance(block, nn.Sequential):
        return block

    stack = []
    for m in block.children():
        # Fuse only when a BatchNorm3d *directly* follows a Conv3d.
        # (The guard on `stack` avoids an IndexError when BN is the
        # first child.)
        if isinstance(m, nn.BatchNorm3d) and stack and isinstance(stack[-1], nn.Conv3d):
            conv = stack[-1]
            bn_st_dict = m.state_dict()
            conv_st_dict = conv.state_dict()

            # BatchNorm params
            eps = m.eps
            mu = bn_st_dict['running_mean']
            var = bn_st_dict['running_var']
            gamma = bn_st_dict['weight']

            if 'bias' in bn_st_dict:
                beta = bn_st_dict['bias']
            else:
                # Non-affine BN has no shift term.
                beta = torch.zeros(gamma.size(0)).float().to(gamma.device)

            # Conv params
            W = conv_st_dict['weight']
            if 'bias' in conv_st_dict:
                bias = conv_st_dict['bias']
            else:
                bias = torch.zeros(W.size(0)).float().to(gamma.device)

            denom = torch.sqrt(var + eps)
            b = beta - gamma.mul(mu).div(denom)
            A = gamma.div(denom)
            bias *= A
            # Broadcast the per-output-channel scale A over the full
            # kernel shape (out_ch, in_ch, kD, kH, kW).
            A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

            W.mul_(A)
            bias.add_(b)

            conv.weight.data.copy_(W)
            if conv.bias is None:
                conv.bias = torch.nn.Parameter(bias)
            else:
                conv.bias.data.copy_(bias)
            # The BN module is intentionally not appended: its effect is
            # now folded into the preceding conv.
        else:
            # Keep every other module — including a BatchNorm3d that does
            # NOT follow a Conv3d (the original code silently dropped it,
            # which changed the network's output).
            stack.append(m)

    if len(stack) > 1:
        return nn.Sequential(*stack)
    # Empty Sequential: return it unchanged instead of raising IndexError.
    return stack[0] if stack else block


def fuse_bn_recursively(model):
    """
    Walk the module tree and fuse Conv3d + BatchNorm3d pairs in every
    nn.Sequential it contains, replacing children in place.

    :param model: nn.Module. Root module to convert.
    :return: nn.Module. The same model object, modified in place.
    """
    for name in model._modules:
        fused = fuse_bn_sequential(model._modules[name])
        model._modules[name] = fused
        # Recurse into any child that itself has submodules.
        if len(fused._modules) > 0:
            fuse_bn_recursively(fused)

    return model