--- /dev/null
+++ b/mmaction/models/recognizers/audio_recognizer.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import RECOGNIZERS
+from .base import BaseRecognizer
+
+
+@RECOGNIZERS.register_module()
+class AudioRecognizer(BaseRecognizer):
+    """Audio recognizer model framework."""
+
+    def forward(self, audios, label=None, return_loss=True):
+        """Define the computation performed at every call."""
+        if return_loss:
+            if label is None:
+                raise ValueError('Label should not be None.')
+            return self.forward_train(audios, label)
+
+        return self.forward_test(audios)
+
+    def forward_train(self, audios, labels):
+        """Defines the computation performed at every call when training."""
+        audios = audios.reshape((-1, ) + audios.shape[2:])
+        x = self.extract_feat(audios)
+        cls_score = self.cls_head(x)
+        gt_labels = labels.squeeze()
+        loss = self.cls_head.loss(cls_score, gt_labels)
+
+        return loss
+
+    def forward_test(self, audios):
+        """Defines the computation performed at every call when evaluation and
+        testing."""
+        num_segs = audios.shape[1]
+        audios = audios.reshape((-1, ) + audios.shape[2:])
+        x = self.extract_feat(audios)
+        cls_score = self.cls_head(x)
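+        # fuse the per-clip scores back into one score per sample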
+        cls_score = self.average_clip(cls_score, num_segs)
+
+        return cls_score.cpu().numpy()
+
+    def forward_gradcam(self, audios):
+        raise NotImplementedError
+
+    def train_step(self, data_batch, optimizer, **kwargs):
+        """The iteration step during training.
+
+        This method defines an iteration step during training, except for the
+        back propagation and optimizer updating, which are done in an
+        optimizer hook. Note that in some complicated cases or models (e.g.
+        GANs), the whole process, including back propagation and optimizer
+        updating, is also defined in this method.
+
+        Args:
+            data_batch (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+                the runner, passed to ``train_step()``. This argument is
+                unused and reserved.
+
+        Returns:
+            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+                ``num_samples``.
+                ``loss`` is a tensor for back propagation, which can be a
+                weighted sum of multiple losses.
+                ``log_vars`` contains all the variables to be sent to the
+                logger.
+                ``num_samples`` indicates the batch size (when the model runs
+                under DDP, this is the per-GPU batch size), which is used for
+                averaging the logs.
+        """
+        audios = data_batch['audios']
+        label = data_batch['label']
+
+        losses = self(audios, label)
+
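+        # sum the individual losses and collect scalars for logging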
+        loss, log_vars = self._parse_losses(losses)
+
+        outputs = dict(
+            loss=loss,
+            log_vars=log_vars,
+            num_samples=len(next(iter(data_batch.values()))))
+
+        return outputs
+
+    def val_step(self, data_batch, optimizer, **kwargs):
+        """The iteration step during validation.
+
+        This method shares the same signature as :func:`train_step`, but is
+        used during val epochs. Note that the evaluation after training
+        epochs is not implemented by this method, but by an evaluation hook.
+        """
+        audios = data_batch['audios']
+        label = data_batch['label']
+
+        losses = self(audios, label)
+
+        loss, log_vars = self._parse_losses(losses)
+
+        outputs = dict(
+            loss=loss,
+            log_vars=log_vars,
+            num_samples=len(next(iter(data_batch.values()))))
+
+        return outputs
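
For reference, here is a minimal sketch of driving this recognizer through one
training iteration. It is not part of the patch: build_recognizer and
AudioTSNHead are the mmaction2 APIs I would expect this class to be paired
with, and the backbone/head settings below are illustrative assumptions rather
than values copied from a real config.

# Sketch only: config values are assumptions, not taken from this patch.
import torch

from mmaction.models import build_recognizer

model_cfg = dict(
    type='AudioRecognizer',
    backbone=dict(type='ResNet', depth=18, in_channels=1, pretrained=None),
    cls_head=dict(
        type='AudioTSNHead',
        num_classes=400,
        in_channels=512,
        dropout_ratio=0.5))

model = build_recognizer(model_cfg)

# Fake batch: 2 samples, 1 clip each, 1-channel 128x80 spectrograms.
# Training uses one clip per sample; multi-clip inputs belong to testing,
# where forward_test averages the per-clip scores.
data_batch = dict(
    audios=torch.randn(2, 1, 1, 128, 80),
    label=torch.tensor([[1], [42]]))

# The optimizer argument is unused and reserved (see the docstring above).
outputs = model.train_step(data_batch, optimizer=None)
print(outputs['log_vars'], outputs['num_samples'])

In a real run the runner supplies the optimizer and consumes log_vars; this
sketch only shows the data flow through forward_train and _parse_losses.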