--- a +++ b/opengait/modeling/backbones/resnet.py @@ -0,0 +1,58 @@ +from torch.nn import functional as F +import torch.nn as nn +from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet +from ..modules import BasicConv2d + + +block_map = {'BasicBlock': BasicBlock, + 'Bottleneck': Bottleneck} + + +class ResNet9(ResNet): + def __init__(self, block, channels=[32, 64, 128, 256], in_channel=1, layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], maxpool=True): + if block in block_map.keys(): + block = block_map[block] + else: + raise ValueError( + "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.") + self.maxpool_flag = maxpool + super(ResNet9, self).__init__(block, layers) + + # Not used # + self.fc = None + ############ + self.inplanes = channels[0] + self.bn1 = nn.BatchNorm2d(self.inplanes) + + self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1) + + self.layer1 = self._make_layer( + block, channels[0], layers[0], stride=strides[0], dilate=False) + + self.layer2 = self._make_layer( + block, channels[1], layers[1], stride=strides[1], dilate=False) + self.layer3 = self._make_layer( + block, channels[2], layers[2], stride=strides[2], dilate=False) + self.layer4 = self._make_layer( + block, channels[3], layers[3], stride=strides[3], dilate=False) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + if blocks >= 1: + layer = super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate) + else: + def layer(x): return x + return layer + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if self.maxpool_flag: + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x +