mmaction/models/heads/trn_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import itertools

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from ..builder import HEADS
from .base import BaseHead


class RelationModule(nn.Module):
    """Relation Module of TRN.

    Args:
        hidden_dim (int): The dimension of the hidden layer of the MLP in the
            relation module.
        num_segments (int): Number of frame segments.
        num_classes (int): Number of classes to be classified.
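
    Example:
        A minimal shape sketch; the batch size and class count below are
        illustrative assumptions, not values taken from this module:

        >>> import torch
        >>> relation = RelationModule(hidden_dim=256, num_segments=8,
        ...                           num_classes=174)
        >>> # input: [N, num_segments, hidden_dim]
        >>> feat = torch.rand(2, 8, 256)
        >>> relation(feat).shape
        torch.Size([2, 174])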
    """

    def __init__(self, hidden_dim, num_segments, num_classes):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_segments = num_segments
        self.num_classes = num_classes
        bottleneck_dim = 512
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.num_segments * self.hidden_dim, bottleneck_dim),
            nn.ReLU(), nn.Linear(bottleneck_dim, self.num_classes))

    def init_weights(self):
        # Use the default kaiming_uniform for all nn.Linear layers.
        pass

    def forward(self, x):
        # [N, num_segs, hidden_dim] -> [N, num_segs * hidden_dim]
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class RelationModuleMultiScale(nn.Module):
    """Relation Module with Multi Scale of TRN.

    Args:
        hidden_dim (int): The dimension of the hidden layer of the MLP in the
            relation module.
        num_segments (int): Number of frame segments.
        num_classes (int): Number of classes to be classified.
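
    Example:
        A minimal shape sketch; the batch size and class count below are
        illustrative assumptions, not values taken from this module:

        >>> import torch
        >>> relation = RelationModuleMultiScale(hidden_dim=256,
        ...                                     num_segments=8,
        ...                                     num_classes=174)
        >>> # input: [N, num_segments, hidden_dim]
        >>> feat = torch.rand(2, 8, 256)
        >>> relation(feat).shape
        torch.Size([2, 174])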
    """

    def __init__(self, hidden_dim, num_segments, num_classes):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_segments = num_segments
        self.num_classes = num_classes

        # generate the multiple frame relations
        self.scales = range(num_segments, 1, -1)
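        # e.g. num_segments=8 gives scales (8, 7, 6, 5, 4, 3, 2), from the
        # relation over all segments down to pairwise frame relations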

        self.relations_scales = []
        self.subsample_scales = []
        max_subsample = 3
        for scale in self.scales:
            # select the different frame features for different scales
            relations_scale = list(
                itertools.combinations(range(self.num_segments), scale))
            self.relations_scales.append(relations_scale)
            # sample at most `max_subsample` relations for each scale
            self.subsample_scales.append(
                min(max_subsample, len(relations_scale)))
        assert len(self.relations_scales[0]) == 1

        bottleneck_dim = 256
        self.fc_fusion_scales = nn.ModuleList()
        for scale in self.scales:
            fc_fusion = nn.Sequential(
                nn.ReLU(), nn.Linear(scale * self.hidden_dim, bottleneck_dim),
                nn.ReLU(), nn.Linear(bottleneck_dim, self.num_classes))
            self.fc_fusion_scales.append(fc_fusion)

    def init_weights(self):
        # Use the default kaiming_uniform for all nn.Linear layers.
        pass

    def forward(self, x):
        # the first one is the largest scale
        act_all = x[:, self.relations_scales[0][0], :]
        act_all = act_all.view(
            act_all.size(0), self.scales[0] * self.hidden_dim)
        act_all = self.fc_fusion_scales[0](act_all)

        for scaleID in range(1, len(self.scales)):
            # iterate over the scales
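            # only a random subset (at most `max_subsample`) of the frame
            # combinations at this scale is evaluated and summed into act_all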
            idx_relations_randomsample = np.random.choice(
                len(self.relations_scales[scaleID]),
                self.subsample_scales[scaleID],
                replace=False)
            for idx in idx_relations_randomsample:
                act_relation = x[:, self.relations_scales[scaleID][idx], :]
                act_relation = act_relation.view(
                    act_relation.size(0),
                    self.scales[scaleID] * self.hidden_dim)
                act_relation = self.fc_fusion_scales[scaleID](act_relation)
                act_all += act_relation
        return act_all


@HEADS.register_module()
class TRNHead(BaseHead):
    """Class head for TRN.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        num_segments (int): Number of frame segments. Default: 8.
        loss_cls (dict): Config for building loss. Default:
            dict(type='CrossEntropyLoss')
        spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
        relation_type (str): The relation module type. Choices are 'TRN' or
            'TRNMultiScale'. Default: 'TRNMultiScale'.
        hidden_dim (int): The dimension of the hidden layer of the MLP in the
            relation module. Default: 256.
        dropout_ratio (float): Probability of dropout layer. Default: 0.8.
        init_std (float): Std value for initialization. Default: 0.001.
        kwargs (dict, optional): Any keyword argument to be used to initialize
            the head.
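
    Example:
        An illustrative sketch only; the backbone feature shape, batch size
        and class count below are assumptions, not values from this file:

        >>> import torch
        >>> head = TRNHead(num_classes=174, in_channels=2048, num_segments=8)
        >>> head.init_weights()
        >>> # input: [batch_size * num_segments, in_channels, 7, 7]
        >>> feat = torch.rand(2 * 8, 2048, 7, 7)
        >>> head(feat, num_segs=8).shape
        torch.Size([2, 174])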
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_segments=8,
                 loss_cls=dict(type='CrossEntropyLoss'),
                 spatial_type='avg',
                 relation_type='TRNMultiScale',
                 hidden_dim=256,
                 dropout_ratio=0.8,
                 init_std=0.001,
                 **kwargs):
        super().__init__(num_classes, in_channels, loss_cls, **kwargs)

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_segments = num_segments
        self.spatial_type = spatial_type
        self.relation_type = relation_type
        self.hidden_dim = hidden_dim
        self.dropout_ratio = dropout_ratio
        self.init_std = init_std

        if self.relation_type == 'TRN':
            self.consensus = RelationModule(self.hidden_dim,
                                            self.num_segments,
                                            self.num_classes)
        elif self.relation_type == 'TRNMultiScale':
            self.consensus = RelationModuleMultiScale(self.hidden_dim,
                                                      self.num_segments,
                                                      self.num_classes)
        else:
            raise ValueError(f'Unknown Relation Type {self.relation_type}!')

        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None
        self.fc_cls = nn.Linear(self.in_channels, self.hidden_dim)

        if self.spatial_type == 'avg':
            # use `nn.AdaptiveAvgPool2d` to pool the spatial dimensions to
            # 1x1, regardless of the input spatial size
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.avg_pool = None

    def init_weights(self):
        """Initialize the parameters from scratch."""
        normal_init(self.fc_cls, std=self.init_std)
        self.consensus.init_weights()

    def forward(self, x, num_segs):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.
            num_segs (int): Not used in TRNHead. By default, `num_segs`
                equals `clip_len * num_clips * num_crops`, which is generated
                automatically in the Recognizer forward phase and is ignored
                by TRN models. The `self.num_segments` needed here is a
                hyperparameter set when the model is built.

        Returns:
            torch.Tensor: The classification scores for input samples.
        """
        # [N * num_segs, in_channels, 7, 7]
        if self.avg_pool is not None:
            x = self.avg_pool(x)
        # [N * num_segs, in_channels, 1, 1]
        x = torch.flatten(x, 1)
        # [N * num_segs, in_channels]
        if self.dropout is not None:
            x = self.dropout(x)

        # [N * num_segs, hidden_dim]
        cls_score = self.fc_cls(x)
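        # [N, num_segments, hidden_dim]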
        cls_score = cls_score.view((-1, self.num_segments) +
                                   cls_score.size()[1:])

        # [N, num_classes]
        cls_score = self.consensus(cls_score)
        return cls_score