py2tfjs/conversion_example/blendbatchnorm.py

# https://github.com/MIPT-Oulu/pytorch_bn_fusion/blob/master/bn_fusion.py
import torch
import torch.nn as nn


def fuse_bn_sequential(block):
    """
    Takes a sequential block and fuses each batch normalization layer into
    the convolution that directly precedes it. With BN statistics (mu, var),
    affine parameters (gamma, beta) and convolution weights (W, b), the fused
    convolution uses
        W' = gamma / sqrt(var + eps) * W
        b' = beta + gamma / sqrt(var + eps) * (b - mu)
    which reproduces bn(conv(x)) exactly in eval mode.
    :param block: nn.Sequential. Source block of the model to convert
    :return: nn.Module. Converted block: an nn.Sequential, or the single
        remaining module if everything else was fused away
    """
    if not isinstance(block, nn.Sequential):
        return block
    stack = []
    for m in block.children():
        # Fuse a BatchNorm3d only when it directly follows a Conv3d; any
        # other module (including an unpaired BatchNorm3d) is kept as is
        if isinstance(m, nn.BatchNorm3d) and stack and isinstance(stack[-1], nn.Conv3d):
            bn_st_dict = m.state_dict()
            conv_st_dict = stack[-1].state_dict()

            # BatchNorm params (identity affine if affine=False)
            eps = m.eps
            mu = bn_st_dict["running_mean"]
            var = bn_st_dict["running_var"]
            if "weight" in bn_st_dict:
                gamma = bn_st_dict["weight"]
            else:
                gamma = torch.ones_like(mu)
            if "bias" in bn_st_dict:
                beta = bn_st_dict["bias"]
            else:
                beta = torch.zeros_like(mu)

            # Conv params
            W = conv_st_dict["weight"]
            if "bias" in conv_st_dict:
                bias = conv_st_dict["bias"]
            else:
                bias = torch.zeros(W.size(0), device=gamma.device)

            # Per-output-channel scale A and shift b of the folded BN
            denom = torch.sqrt(var + eps)
            b = beta - gamma.mul(mu).div(denom)
            A = gamma.div(denom)
            bias *= A
            # Broadcast A across the (out_ch, in_ch, kD, kH, kW) weight tensor
            A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

            W.mul_(A)
            bias.add_(b)

            stack[-1].weight.data.copy_(W)
            if stack[-1].bias is None:
                stack[-1].bias = torch.nn.Parameter(bias)
            else:
                stack[-1].bias.data.copy_(bias)
        else:
            stack.append(m)

    if len(stack) > 1:
        return nn.Sequential(*stack)
    else:
        return stack[0]
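

# Illustrative sanity check, not part of the original file: build a toy
# Conv3d + BatchNorm3d block and confirm that the fused result reproduces
# the original outputs in eval mode. The helper name, shapes and tolerance
# are assumptions chosen for this example.
def _check_fuse_bn_sequential():
    torch.manual_seed(0)
    block = nn.Sequential(nn.Conv3d(2, 4, kernel_size=3, padding=1),
                          nn.BatchNorm3d(4))
    # One training-mode pass moves the running statistics away from their
    # (mean=0, var=1) initialization, making the check non-trivial
    block.train()
    block(torch.randn(8, 2, 5, 5, 5))
    block.eval()

    x = torch.randn(1, 2, 5, 5, 5)
    with torch.no_grad():
        expected = block(x)
        fused = fuse_bn_sequential(block)  # the lone fused Conv3d remains
        actual = fused(x)
    assert torch.allclose(expected, actual, atol=1e-5)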


def fuse_bn_recursively(model):
    # Fuse every nn.Sequential in the module tree, depth first
    for module_name in model._modules:
        model._modules[module_name] = fuse_bn_sequential(model._modules[module_name])
        if len(model._modules[module_name]._modules) > 0:
            fuse_bn_recursively(model._modules[module_name])
    return model
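

# Minimal usage sketch, not part of the original file: the nested Sequential
# below is an assumed stand-in for a real model from the surrounding project.
# Fusion bakes in the BN running statistics, so the model should be in eval
# mode before and after conversion.
if __name__ == "__main__":
    _check_fuse_bn_sequential()

    model = nn.Sequential(
        nn.Sequential(nn.Conv3d(1, 8, 3, padding=1), nn.BatchNorm3d(8), nn.ReLU()),
        nn.Sequential(nn.Conv3d(8, 1, 3, padding=1), nn.BatchNorm3d(1)),
    )
    # Populate the BN running statistics with one training-mode pass
    model.train()
    model(torch.randn(4, 1, 8, 8, 8))
    model.eval()

    x = torch.randn(1, 1, 8, 8, 8)
    with torch.no_grad():
        before = model(x)
        model = fuse_bn_recursively(model)
        after = model(x)
    # No BatchNorm3d layers remain and the outputs are unchanged
    assert not any(isinstance(m, nn.BatchNorm3d) for m in model.modules())
    print("max abs diff after fusion:", (after - before).abs().max().item())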