a b/opengait/modeling/backbones/resnet.py
1
from torch.nn import functional as F
2
import torch.nn as nn
3
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
4
from ..modules import BasicConv2d
5
6
7
# Lookup table from the ``block`` config string accepted by ResNet9 to the
# torchvision residual block class of the same name.
block_map = dict(
    BasicBlock=BasicBlock,
    Bottleneck=Bottleneck,
)
9
10
11
class ResNet9(ResNet):
    """Configurable, shallow ResNet backbone built on torchvision's ``ResNet``.

    The stem is replaced by ``BasicConv2d(in_channel, channels[0], 3, 1, 1)``
    followed by BatchNorm, the inherited ReLU and (optionally) the inherited
    max-pool. The four residual stages are rebuilt with configurable widths,
    depths and strides. The classification head ``fc`` is disabled, and
    ``forward`` returns the ``layer4`` feature map.

    Args:
        block: Residual block name, ``'BasicBlock'`` or ``'Bottleneck'``
            (looked up in ``block_map``).
        channels: ``planes`` argument for each of the four stages.
        in_channel: Number of input channels of the stem convolution.
        layers: Number of residual blocks per stage. A value ``< 1`` turns
            that stage into an identity mapping.
        strides: Stride passed to each stage's ``_make_layer`` call.
        maxpool: Whether ``forward`` applies ``self.maxpool`` after the stem.

    Raises:
        ValueError: If ``block`` is not a key of ``block_map``.
    """

    def __init__(self, block, channels=(32, 64, 128, 256), in_channel=1,
                 layers=(1, 2, 2, 1), strides=(1, 2, 2, 1), maxpool=True):
        if block not in block_map:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        block = block_map[block]

        # The parent constructor builds the stock torchvision modules (calling
        # the ``_make_layer`` override below); the stem and stages are then
        # replaced with the configurable versions.
        # NOTE(review): modules rebuilt here are created after the parent's
        # weight-init loop, so they keep PyTorch's default initialisation —
        # confirm this is intended.
        super(ResNet9, self).__init__(block, layers)
        self.maxpool_flag = maxpool

        # The classification head is not used by this backbone.
        self.fc = None

        # ``_make_layer`` reads and advances ``self.inplanes``, so seed it
        # with the stem width before building the stages.
        self.inplanes = channels[0]
        self.bn1 = nn.BatchNorm2d(self.inplanes)

        self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1)

        self.layer1 = self._make_layer(
            block, channels[0], layers[0], stride=strides[0], dilate=False)

        self.layer2 = self._make_layer(
            block, channels[1], layers[1], stride=strides[1], dilate=False)
        self.layer3 = self._make_layer(
            block, channels[2], layers[2], stride=strides[2], dilate=False)
        self.layer4 = self._make_layer(
            block, channels[3], layers[3], stride=strides[3], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage, or an identity stage when ``blocks < 1``.

        Non-empty stages are delegated to torchvision's
        ``ResNet._make_layer``. Empty stages return ``nn.Identity()`` instead
        of a bare function, so the placeholder is a registered submodule that
        can be pickled, scripted and shown in ``repr``. Its forward pass
        returns the input unchanged and it adds no parameters.
        """
        if blocks < 1:
            return nn.Identity()
        return super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)

    def forward(self, x):
        """Run the stem and the four stages and return the ``layer4`` output."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.maxpool_flag:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
58