
CaraNet/pretrain/ResNet.py
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 21 21:52:37 2021

@author: angelou
"""

import torch.nn as nn
import math


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    # ResNet50 with two branches
    def __init__(self):
        # self.inplanes = 128
        self.inplanes = 64
        super(ResNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        # Two parallel high-level branches, each mirroring the standard
        # ResNet-50 layer3/layer4 stages, so forward() can return two outputs.
        self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # Reset inplanes so the second branch is also built on the layer2 output.
        self.inplanes = 512
        self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # He-style initialisation for conv weights; BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        # branch 1
        x1 = self.layer3_1(x)
        x1 = self.layer4_1(x1)

        # branch 2
        x2 = self.layer3_2(x)
        x2 = self.layer4_2(x2)

        return x1, x2
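

# A minimal shape check (not part of the original file), assuming a standard
# 4-D image batch as input; the 352x352 size is only an illustrative choice.
if __name__ == '__main__':
    import torch

    model = ResNet()
    dummy = torch.randn(1, 3, 352, 352)  # hypothetical input batch
    feat1, feat2 = model(dummy)
    # Both branches end in a Bottleneck stage with 512 * 4 = 2048 channels,
    # and the overall stride is 32 (conv1, maxpool, and three strided stages),
    # so a 352x352 input should yield 1 x 2048 x 11 x 11 feature maps.
    print(feat1.shape, feat2.shape)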