|
a |
|
b/mmaction/models/heads/tsm_head.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import torch |
|
|
3 |
import torch.nn as nn |
|
|
4 |
from mmcv.cnn import normal_init |
|
|
5 |
|
|
|
6 |
from ..builder import HEADS |
|
|
7 |
from .base import AvgConsensus, BaseHead |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
@HEADS.register_module()
class TSMHead(BaseHead):
    """Class head for TSM.

    Pools the backbone feature spatially, applies dropout and a linear
    classifier to every frame, then aggregates the per-frame scores of each
    sample with a consensus module.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        num_segments (int): Number of frame segments. Default: 8.
        loss_cls (dict): Config for building loss.
            Default: dict(type='CrossEntropyLoss')
        spatial_type (str): Pooling type in spatial dimension. Only 'avg'
            enables pooling; any other value disables it. Default: 'avg'.
        consensus (dict): Consensus config dict. Only
            ``type='AvgConsensus'`` builds a consensus module; any other
            type leaves ``self.consensus`` as None and makes ``forward``
            raise. Default: dict(type='AvgConsensus', dim=1).
        dropout_ratio (float): Probability of dropout layer. ``0`` disables
            dropout. Default: 0.8.
        init_std (float): Std value for Initiation. Default: 0.001.
        is_shift (bool): Indicating whether the feature is shifted.
            Default: True.
        temporal_pool (bool): Indicating whether feature is temporal pooled.
            Default: False.
        kwargs (dict, optional): Any keyword argument to be used to initialize
            the head.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_segments=8,
                 loss_cls=dict(type='CrossEntropyLoss'),
                 spatial_type='avg',
                 consensus=dict(type='AvgConsensus', dim=1),
                 dropout_ratio=0.8,
                 init_std=0.001,
                 is_shift=True,
                 temporal_pool=False,
                 **kwargs):
        super().__init__(num_classes, in_channels, loss_cls, **kwargs)

        self.spatial_type = spatial_type
        self.dropout_ratio = dropout_ratio
        self.num_segments = num_segments
        self.init_std = init_std
        self.is_shift = is_shift
        self.temporal_pool = temporal_pool

        # Work on a copy so that popping 'type' never mutates the caller's
        # config (or the shared default argument).
        consensus_ = consensus.copy()

        consensus_type = consensus_.pop('type')
        if consensus_type == 'AvgConsensus':
            self.consensus = AvgConsensus(**consensus_)
        else:
            # Unsupported consensus types are tolerated at construction time;
            # `forward` reports the problem explicitly when it is reached.
            self.consensus = None

        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None
        self.fc_cls = nn.Linear(self.in_channels, self.num_classes)

        if self.spatial_type == 'avg':
            # use `nn.AdaptiveAvgPool2d` to adaptively match the in_channels.
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.avg_pool = None

    def init_weights(self):
        """Initiate the parameters from scratch."""
        normal_init(self.fc_cls, std=self.init_std)

    def forward(self, x, num_segs):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.
            num_segs (int): Useless in TSMHead. By default, `num_segs`
                is equal to `clip_len * num_clips * num_crops`, which is
                automatically generated in Recognizer forward phase and
                useless in TSM models. The `self.num_segments` we need is a
                hyper parameter to build TSM models.

        Returns:
            torch.Tensor: The classification scores for input samples.

        Raises:
            TypeError: If the head was built with a consensus type other
                than 'AvgConsensus', so no consensus module is available.
        """
        # [N * num_segs, in_channels, 7, 7]
        if self.avg_pool is not None:
            x = self.avg_pool(x)
        # [N * num_segs, in_channels, 1, 1]
        x = torch.flatten(x, 1)
        # [N * num_segs, in_channels]
        if self.dropout is not None:
            x = self.dropout(x)
        # [N * num_segs, num_classes]
        cls_score = self.fc_cls(x)

        if self.is_shift and self.temporal_pool:
            # [2 * N, num_segs // 2, num_classes]
            cls_score = cls_score.view((-1, self.num_segments // 2) +
                                       cls_score.size()[1:])
        else:
            # [N, num_segs, num_classes]
            cls_score = cls_score.view((-1, self.num_segments) +
                                       cls_score.size()[1:])

        if self.consensus is None:
            # Calling None would fail with an opaque "'NoneType' object is
            # not callable"; raise the same exception type with a message
            # that points at the actual misconfiguration.
            raise TypeError(
                'TSMHead has no consensus module: only '
                "consensus=dict(type='AvgConsensus', ...) is supported.")

        # [N, 1, num_classes]
        cls_score = self.consensus(cls_score)
        # [N, num_classes]
        return cls_score.squeeze(1)