Diff of /CellGraph/resnet.py [000000] .. [2095ed]

'''
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file are hugely influenced by [2],
which is implemented for ImageNet and doesn't have option A for identity.
Moreover, most of the implementations on the web are copy-paste from
torchvision's resnet and have the wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
numbers of layers and parameters:

name      | layers | params
ResNet20  |    20  | 0.27M
ResNet32  |    32  | 0.46M
ResNet44  |    44  | 0.66M
ResNet56  |    56  | 0.85M
ResNet110 |   110  |  1.7M
ResNet1202|  1202  | 19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention
the author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    # He (Kaiming) initialization for every conv and linear layer; ResNet
    # applies this to all submodules via self.apply() below.
    # print(m.__class__.__name__)  # debug: trace which modules are visited
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wraps an arbitrary function as an nn.Module so it can sit inside the
    module tree (used below for the parameter-free option-A shortcut)."""

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10, the ResNet paper uses option A: subsample the
                # identity with stride 2 and zero-pad the extra channels.
                # This keeps the shortcut parameter-free.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                # Option B: a learned 1x1 projection shortcut, as in [2].
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


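# A quick shape check for the option-A shortcut (an illustrative helper added
# for exposition; the helper and its name are not part of the original file).
# At the first stride-2 transition, x[:, :, ::2, ::2] halves both spatial dims
# and F.pad zero-pads the channel dim from 16 to 32 (planes//4 = 8 per side).
def _option_a_shape_check():
    x = torch.randn(4, 16, 32, 32)                                    # arbitrary NCHW batch
    y = F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, 8, 8), "constant", 0)   # pad pairs run W, H, then C
    assert y.shape == (4, 32, 16, 16)

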
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage may downsample; the rest use stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

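    # For illustration, layer2 of resnet20 ([3, 3, 3] blocks, stride 2) unrolls
    # to the following (a sketch of what _make_layer builds under the
    # definitions in this file):
    #
    #     nn.Sequential(
    #         BasicBlock(16, 32, stride=2),  # downsamples; option-A shortcut
    #         BasicBlock(32, 32, stride=1),  #   pads 16 -> 32 channels
    #         BasicBlock(32, 32, stride=1),
    #     )
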
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])  # global average pool to 1x1
        out = out.view(out.size(0), -1)
        out = self.linear(out)                  # classifier head
        return out


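# Shape trace for a CIFAR10-sized input (batch size B), following the forward
# pass above:
#
#     input            B x  3 x 32 x 32
#     conv1 + bn1      B x 16 x 32 x 32
#     layer1           B x 16 x 32 x 32
#     layer2           B x 32 x 16 x 16
#     layer3           B x 64 x  8 x  8
#     avg_pool2d(8)    B x 64 x  1 x  1   -> flattened to B x 64
#     linear           B x num_classes

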
def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


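# Example usage (a minimal sketch; this helper is added for illustration and is
# never called). Each BasicBlock contributes 2 convolutions, so [n, n, n]
# blocks give 6n + 2 weight layers: resnet20 has n = 3, i.e. 6*3 + 2 = 20.
def _example_usage():
    net = resnet20()
    images = torch.randn(8, 3, 32, 32)  # a batch of CIFAR10-sized inputs
    logits = net(images)                # forward pass through the full network
    assert logits.shape == (8, 10)      # one logit per CIFAR10 class

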
def test(net):
    # Count trainable parameters, and count weight layers as the trainable
    # tensors with more than one dimension (conv and linear weights, as
    # opposed to biases and batch-norm parameters).
    total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("Total number of params", total_params)
    print("Total layers", len([p for p in net.parameters() if p.requires_grad and p.dim() > 1]))


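# Running `python resnet.py` executes the block below. For each constructor the
# printed parameter count should match the docstring table (e.g. ~0.27M for
# resnet20), and the layer count should equal the nominal depth (20 for resnet20).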
if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()