a b/utils/init_net.py
1
import torch.nn as nn
2
3
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize the parameters of ``net``'s Conv/Linear/BatchNorm2d submodules in place.

    Submodules are matched by class name (visited recursively via ``net.apply``):

    * class name contains ``'Conv'`` or ``'Linear'`` and the module has a
      non-None ``weight``: the weight is initialized according to
      ``init_type`` and the bias, if present, is set to 0.
    * class name contains ``'BatchNorm2d'``: the weight is drawn from
      N(1.0, gain) and the bias is set to 0. Either is skipped when it is
      ``None`` (e.g. ``affine=False``).

    Args:
        net: the ``nn.Module`` to initialize.
        init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'``,
            ``'orthogonal'``.
        gain: standard deviation for ``'normal'`` and for BatchNorm2d
            weights; gain factor for ``'xavier'`` and ``'orthogonal'``.
            Not used by ``'kaiming'``.

    Raises:
        NotImplementedError: if ``init_type`` is not supported. This is
            raised before any parameter is modified.
    """
    weight_inits = {
        'normal': lambda w: nn.init.normal_(w, 0.0, gain),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: nn.init.orthogonal_(w, gain=gain),
    }
    # Validate eagerly so a bad init_type fails even for nets that contain
    # no Conv/Linear layers, and before any weights have been overwritten.
    try:
        init_weight = weight_inits[init_type]
    except KeyError:
        raise NotImplementedError(
            'initialization method [%s] is not implemented' % init_type
        ) from None

    def init_func(m):
        classname = m.__class__.__name__
        # getattr(..., None) treats "no attribute" and "attribute is None"
        # (e.g. bias=False, affine=False) the same way: nothing to init.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            init_weight(weight)
            if bias is not None:
                nn.init.constant_(bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is a 1-D scale vector, so only a normal
            # draw centred on 1 is applied here.
            if weight is not None:
                nn.init.normal_(weight, 1.0, gain)
            if bias is not None:
                nn.init.constant_(bias, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)