Switch to side-by-side view

--- a
+++ b/v3/py2tfjs/blendbatchnorm.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+
+def fuse_bn_sequential(block):
+    """
+    This function takes a sequential block and fuses the batch normalization with convolution
+    :param model: nn.Sequential. Source resnet model
+    :return: nn.Sequential. Converted block
+    """
+    if not isinstance(block, nn.Sequential):
+        return block
+    stack = []
+    for m in block.children():
+        if isinstance(m, nn.BatchNorm3d):
+            if isinstance(stack[-1], nn.Conv3d):
+                bn_st_dict = m.state_dict()
+                conv_st_dict = stack[-1].state_dict()
+
+                # BatchNorm params
+                eps = m.eps
+                mu = bn_st_dict['running_mean']
+                var = bn_st_dict['running_var']
+                gamma = bn_st_dict['weight']
+
+                if 'bias' in bn_st_dict:
+                    beta = bn_st_dict['bias']
+                else:
+                    beta = torch.zeros(gamma.size(0)).float().to(gamma.device)
+
+                # Conv params
+                W = conv_st_dict['weight']
+                if 'bias' in conv_st_dict:
+                    bias = conv_st_dict['bias']
+                else:
+                    bias = torch.zeros(W.size(0)).float().to(gamma.device)
+
+                denom = torch.sqrt(var + eps)
+                b = beta - gamma.mul(mu).div(denom)
+                A = gamma.div(denom)
+                bias *= A
+                A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)
+
+                W.mul_(A)
+                bias.add_(b)
+
+                stack[-1].weight.data.copy_(W)
+                if stack[-1].bias is None:
+                    stack[-1].bias = torch.nn.Parameter(bias)
+                else:
+                    stack[-1].bias.data.copy_(bias)
+
+        else:
+            stack.append(m)
+
+    if len(stack) > 1:
+        return nn.Sequential(*stack)
+    else:
+        return stack[0]
+
+
+def fuse_bn_recursively(model):
+    for module_name in model._modules:
+        model._modules[module_name] = fuse_bn_sequential(model._modules[module_name])
+        if len(model._modules[module_name]._modules) > 0:
+            fuse_bn_recursively(model._modules[module_name])
+
+    return model