|
a |
|
b/mmaction/models/heads/trn_head.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import itertools |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
import torch |
|
|
6 |
import torch.nn as nn |
|
|
7 |
from mmcv.cnn import normal_init |
|
|
8 |
|
|
|
9 |
from ..builder import HEADS |
|
|
10 |
from .base import BaseHead |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
class RelationModule(nn.Module):
    """Single-scale relation module of TRN.

    Concatenates the features of all ``num_segments`` segments of a sample
    and maps them to class scores with a small MLP
    (ReLU -> Linear -> ReLU -> Linear) that has a fixed 512-d bottleneck.

    Args:
        hidden_dim (int): The dimension of hidden layer of MLP in relation
            module (the feature size contributed by each segment).
        num_segments (int): Number of frame segments.
        num_classes (int): Number of classes to be classified.
    """

    def __init__(self, hidden_dim, num_segments, num_classes):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_segments = num_segments
        self.num_classes = num_classes

        in_features = num_segments * hidden_dim
        bottleneck = 512
        # Keep this exact layer order: the state_dict keys
        # (``classifier.1.*`` / ``classifier.3.*``) depend on it.
        layers = [
            nn.ReLU(),
            nn.Linear(in_features, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def init_weights(self):
        """Keep PyTorch's default kaiming_uniform init of the linear layers."""
        return None

    def forward(self, x):
        """Map the per-segment features of each sample to class scores.

        Args:
            x (torch.Tensor): Features whose non-batch dims flatten to
                ``num_segments * hidden_dim`` values per sample.

        Returns:
            torch.Tensor: Class scores of shape [N, num_classes].
        """
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)  # [N, num_segs * hidden_dim]
        return self.classifier(flat)
|
|
43 |
|
|
|
44 |
|
|
|
45 |
class RelationModuleMultiScale(nn.Module):
    """Relation Module with Multi Scale of TRN.

    For every scale ``k`` from ``num_segments`` down to 2, all k-frame
    subsets of the segment indices are enumerated with
    ``itertools.combinations``. Each scale gets its own MLP
    (ReLU -> Linear -> ReLU -> Linear, 256-d bottleneck) that maps the
    concatenated features of one subset to class scores.

    The forward output is the sum of:

    - the scores of the single full-length relation, and
    - the scores of up to 3 randomly chosen relations at every smaller scale.

    Args:
        hidden_dim (int): The dimension of hidden layer of MLP in relation
            module (the feature size contributed by each segment).
        num_segments (int): Number of frame segments. Must be at least 2,
            because the smallest relation scale spans 2 frames.
        num_classes (int): Number of classes to be classified.

    Raises:
        ValueError: If ``num_segments`` is smaller than 2.
    """

    def __init__(self, hidden_dim, num_segments, num_classes):
        super().__init__()
        if num_segments < 2:
            # ``range(num_segments, 1, -1)`` would be empty and no relation
            # scale could be built (this used to surface as an IndexError).
            raise ValueError(
                'RelationModuleMultiScale requires num_segments >= 2, '
                f'but got {num_segments}')
        self.hidden_dim = hidden_dim
        self.num_segments = num_segments
        self.num_classes = num_classes

        # generate the multiple frame relations, largest scale first
        self.scales = range(num_segments, 1, -1)

        self.relations_scales = []
        self.subsample_scales = []
        max_subsample = 3
        for scale in self.scales:
            # select the different frame features for different scales
            relations_scale = list(
                itertools.combinations(range(self.num_segments), scale))
            self.relations_scales.append(relations_scale)
            # sample `max_subsample` relation_scale at most
            self.subsample_scales.append(
                min(max_subsample, len(relations_scale)))
        # The largest scale uses every segment, so it has exactly one relation.
        assert len(self.relations_scales[0]) == 1

        bottleneck_dim = 256
        self.fc_fusion_scales = nn.ModuleList()
        for scale in self.scales:
            fc_fusion = nn.Sequential(
                nn.ReLU(), nn.Linear(scale * self.hidden_dim, bottleneck_dim),
                nn.ReLU(), nn.Linear(bottleneck_dim, self.num_classes))
            self.fc_fusion_scales.append(fc_fusion)

    def init_weights(self):
        # Use the default kaiming_uniform for all nn.linear layers.
        pass

    def forward(self, x):
        """Fuse multi-scale frame relations into class scores.

        Args:
            x (torch.Tensor): Per-segment features, indexed as
                ``x[:, segment_indices, :]``. The expected shape is
                [N, num_segments, hidden_dim].

        Returns:
            torch.Tensor: Summed class scores of shape [N, num_classes].
        """
        # The first (largest) scale has a single relation over all segments.
        act_all = x[:, self.relations_scales[0][0], :]
        act_all = act_all.view(
            act_all.size(0), self.scales[0] * self.hidden_dim)
        act_all = self.fc_fusion_scales[0](act_all)

        for scaleID in range(1, len(self.scales)):
            # Pick up to ``subsample_scales[scaleID]`` distinct relations at
            # this scale. This uses NumPy's global RNG and is not gated on
            # ``self.training``, so predictions are stochastic whenever a
            # scale has more than 3 relations.
            idx_relations_randomsample = np.random.choice(
                len(self.relations_scales[scaleID]),
                self.subsample_scales[scaleID],
                replace=False)
            for idx in idx_relations_randomsample:
                act_relation = x[:, self.relations_scales[scaleID][idx], :]
                act_relation = act_relation.view(
                    act_relation.size(0),
                    self.scales[scaleID] * self.hidden_dim)
                act_relation = self.fc_fusion_scales[scaleID](act_relation)
                # Out-of-place add: do not mutate the tensor returned by
                # fc_fusion_scales[0], which forward hooks may still hold.
                act_all = act_all + act_relation
        return act_all
|
|
110 |
|
|
|
111 |
|
|
|
112 |
@HEADS.register_module()
class TRNHead(BaseHead):
    """Classification head for TRN.

    Each segment feature is pooled, flattened, optionally dropped out and
    projected to ``hidden_dim`` by ``fc_cls``. The projected features are then
    regrouped into ``num_segments`` per sample and fused by a relation module.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        num_segments (int): Number of frame segments. Default: 8.
        loss_cls (dict): Config for building loss. Default:
            dict(type='CrossEntropyLoss').
        spatial_type (str): Pooling type in spatial dimension. Only 'avg'
            adds a pooling layer; any other value disables pooling.
            Default: 'avg'.
        relation_type (str): The relation module type. Choices are 'TRN' or
            'TRNMultiScale'. Default: 'TRNMultiScale'.
        hidden_dim (int): The dimension of hidden layer of MLP in relation
            module. Default: 256.
        dropout_ratio (float): Probability of dropout layer. A value of 0
            disables dropout. Default: 0.8.
        init_std (float): Std value used to initialize ``fc_cls``.
            Default: 0.001.
        kwargs (dict, optional): Any keyword argument to be used to initialize
            the head.

    Raises:
        ValueError: If ``relation_type`` is not a supported choice.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_segments=8,
                 loss_cls=dict(type='CrossEntropyLoss'),
                 spatial_type='avg',
                 relation_type='TRNMultiScale',
                 hidden_dim=256,
                 dropout_ratio=0.8,
                 init_std=0.001,
                 **kwargs):
        super().__init__(num_classes, in_channels, loss_cls, **kwargs)

        # Mirror the hyper-parameters onto the instance.
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_segments = num_segments
        self.spatial_type = spatial_type
        self.relation_type = relation_type
        self.hidden_dim = hidden_dim
        self.dropout_ratio = dropout_ratio
        self.init_std = init_std

        relation_args = (self.hidden_dim, self.num_segments, self.num_classes)
        if self.relation_type == 'TRN':
            self.consensus = RelationModule(*relation_args)
        elif self.relation_type == 'TRNMultiScale':
            self.consensus = RelationModuleMultiScale(*relation_args)
        else:
            raise ValueError(f'Unknown Relation Type {self.relation_type}!')

        # A ratio of exactly 0 means "no dropout layer at all".
        self.dropout = (
            nn.Dropout(p=self.dropout_ratio)
            if self.dropout_ratio != 0 else None)
        # Per-segment projection from backbone channels to relation features.
        self.fc_cls = nn.Linear(self.in_channels, self.hidden_dim)

        # AdaptiveAvgPool2d(1) collapses any spatial size down to 1x1.
        self.avg_pool = (
            nn.AdaptiveAvgPool2d(1) if self.spatial_type == 'avg' else None)

    def init_weights(self):
        """Initialize ``fc_cls`` via ``normal_init`` (``std=init_std``) and
        delegate to the relation module's own ``init_weights``."""
        normal_init(self.fc_cls, std=self.init_std)
        self.consensus.init_weights()

    def forward(self, x, num_segs):
        """Compute classification scores for a batch of segment features.

        Args:
            x (torch.Tensor): The input data, presumably
                [N * num_segments, in_channels, H, W] (e.g. 7x7 spatially).
            num_segs (int): Unused by TRNHead. It is generated automatically
                in the Recognizer forward phase as
                ``clip_len * num_clips * num_crops``. The segment count used
                for regrouping is the ``self.num_segments`` hyper-parameter.

        Returns:
            torch.Tensor: The classification scores for input samples.
        """
        if self.avg_pool is not None:
            x = self.avg_pool(x)  # -> [N * num_segs, in_channels, 1, 1]
        x = torch.flatten(x, 1)  # -> [N * num_segs, in_channels]
        if self.dropout is not None:
            x = self.dropout(x)

        segment_feats = self.fc_cls(x)  # -> [N * num_segs, hidden_dim]
        # Regroup consecutive rows into samples of num_segments segments.
        segment_feats = segment_feats.view(
            (-1, self.num_segments) + segment_feats.size()[1:])
        # [N, num_segs, hidden_dim] -> [N, num_classes]
        return self.consensus(segment_feats)