b/CellGraph/resnet.py
'''
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file are hugely influenced by [2],
which is implemented for ImageNet and doesn't have option A for the identity
shortcut. Moreover, most of the implementations on the web are copy-pasted from
torchvision's resnet and have the wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
numbers of layers and parameters:

name       | layers | params
ResNet20   |     20 |  0.27M
ResNet32   |     32 |  0.46M
ResNet44   |     44 |  0.66M
ResNet56   |     56 |  0.85M
ResNet110  |    110 |   1.7M
ResNet1202 |   1202 |  19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
'''
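
# Depth bookkeeping: with n BasicBlocks per stage the depth works out to
# 6n + 2 (three stages, two 3x3 convs per block, plus the stem conv and the
# final linear layer). For instance, resnet20 below uses num_blocks = [3, 3, 3],
# giving 6 * 3 + 2 = 20 layers.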

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    classname = m.__class__.__name__
    print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        # kaiming_normal is deprecated; use the in-place kaiming_normal_
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A: a parameter-free
                # shortcut that subsamples the input spatially (every other row
                # and column) and zero-pads the missing channels. E.g. for a
                # 16 -> 32 transition, planes // 4 = 8 zero channels are padded
                # on each side of the channel dim: 16 + 8 + 8 = 32.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
            elif option == 'B':
                # Option B: projection shortcut (1x1 conv + BN), which adds parameters.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        # CIFAR-10 stem and three stages: 16, 32 and 64 channels at
        # 32x32, 16x16 and 8x8 resolution respectively.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)
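
    # Example: _make_layer(BasicBlock, 32, num_blocks=3, stride=2) builds blocks
    # with strides [2, 1, 1]: only the first block of a stage downsamples and
    # widens the channels (self.in_planes goes 16 -> 32); the rest preserve shape.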

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        # Note: the classifier head self.linear is defined in __init__ but not
        # applied here, so forward() returns the pooled 64-dim feature vector
        # rather than class logits.
        return out


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])

def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))
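
# For resnet20 this reports roughly 0.27M parameters and 20 weight layers
# (19 convs plus the final linear), matching the table in the module docstring.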


if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()
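
    # A minimal smoke test (sketch): push a dummy CIFAR-10-sized batch through
    # resnet20. Since forward() does not apply self.linear, the output is the
    # pooled 64-dim feature vector for each image.
    x = torch.randn(2, 3, 32, 32)
    print(resnet20()(x).shape)  # expected: torch.Size([2, 64])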