Diff of /utils/init_net.py [000000] .. [98e649]

Switch to side-by-side view

--- a
+++ b/utils/init_net.py
@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+def init_weights(net, init_type='normal', gain=0.02):
+    def init_func(m):
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                nn.init.normal_(m.weight.data, 0.0, gain)
+            elif init_type == 'xavier':
+                nn.init.xavier_normal_(m.weight.data, gain=gain)
+            elif init_type == 'kaiming':
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                nn.init.orthogonal_(m.weight.data, gain=gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                nn.init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:
+            nn.init.normal_(m.weight.data, 1.0, gain)
+            nn.init.constant_(m.bias.data, 0.0)
+
+    print('initialize network with %s' % init_type)
+    net.apply(init_func)
\ No newline at end of file