[c1b1c5]: / ViTPose / mmpose / models / backbones / base_backbone.py

Download this file

44 lines (35 with data), 1.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from abc import ABCMeta, abstractmethod

import torch.nn as nn

# from .utils import load_checkpoint
from mmcv_custom.checkpoint import load_checkpoint
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function;
    because `forward` is abstract, the class cannot be instantiated until a
    subclass overrides it.
    """

    def init_weights(self, pretrained=None, patch_padding='pad'):
        """Init backbone models.

        Args:
            pretrained (str | os.PathLike | None): If pretrained is a path
                (a ``str`` or an ``os.PathLike`` such as ``pathlib.Path``),
                the backbone is initialized by loading that checkpoint with
                ``load_checkpoint(..., strict=False)``. If pretrained is None,
                nothing is loaded here and the default initializer or a
                customized initializer in subclasses applies.
            patch_padding (str): Forwarded unchanged to ``load_checkpoint``;
                only used when ``pretrained`` is a path. Default: 'pad'.
                NOTE(review): presumably selects how patch-embedding weights
                are adapted when loading — confirm against
                ``mmcv_custom.checkpoint.load_checkpoint``.

        Raises:
            TypeError: If ``pretrained`` is neither a path nor None.
        """
        # Normalize path-like objects (e.g. pathlib.Path) to str so that
        # load_checkpoint keeps receiving a plain string path, exactly as it
        # did when only str was accepted here.
        if isinstance(pretrained, os.PathLike):
            pretrained = os.fsdecode(pretrained)

        if isinstance(pretrained, str):
            # Loading messages are routed to the root logger.
            logger = logging.getLogger()
            # strict=False: presumably tolerates missing/unexpected keys as
            # in mmcv's loader — TODO confirm for mmcv_custom.
            load_checkpoint(
                self,
                pretrained,
                strict=False,
                logger=logger,
                patch_padding=patch_padding)
        elif pretrained is None:
            # use default initializer or customized initializer in subclasses
            pass
        else:
            raise TypeError('pretrained must be a str, os.PathLike or None.'
                            f' But received {type(pretrained)}.')

    @abstractmethod
    def forward(self, x):
        """Forward function.

        Must be overridden by every concrete backbone.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """