Switch to side-by-side view

--- a
+++ b/tests/test_metrics/test_losses.py
@@ -0,0 +1,332 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv import ConfigDict
+from numpy.testing import assert_almost_equal, assert_array_almost_equal
+from torch.autograd import Variable
+
+from mmaction.models import (BCELossWithLogits, BinaryLogisticRegressionLoss,
+                             BMNLoss, CrossEntropyLoss, HVULoss, NLLLoss,
+                             OHEMHingeLoss, SSNLoss)
+
+
+def test_hvu_loss():
+    pred = torch.tensor([[-1.0525, -0.7085, 0.1819, -0.8011],
+                         [0.1555, -1.5550, 0.5586, 1.9746]])
+    gt = torch.tensor([[1., 0., 0., 0.], [0., 0., 1., 1.]])
+    mask = torch.tensor([[1., 1., 0., 0.], [0., 0., 1., 1.]])
+    category_mask = torch.tensor([[1., 0.], [0., 1.]])
+    categories = ['action', 'scene']
+    category_nums = (2, 2)
+    category_loss_weights = (1, 1)
+    loss_all_nomask_sum = HVULoss(
+        categories=categories,
+        category_nums=category_nums,
+        category_loss_weights=category_loss_weights,
+        loss_type='all',
+        with_mask=False,
+        reduction='sum')
+    loss = loss_all_nomask_sum(pred, gt, mask, category_mask)
+    loss1 = F.binary_cross_entropy_with_logits(pred, gt, reduction='none')
+    loss1 = torch.sum(loss1, dim=1)
+    assert torch.eq(loss['loss_cls'], torch.mean(loss1))
+
+    loss_all_mask = HVULoss(
+        categories=categories,
+        category_nums=category_nums,
+        category_loss_weights=category_loss_weights,
+        loss_type='all',
+        with_mask=True)
+    loss = loss_all_mask(pred, gt, mask, category_mask)
+    loss1 = F.binary_cross_entropy_with_logits(pred, gt, reduction='none')
+    loss1 = torch.sum(loss1 * mask, dim=1) / torch.sum(mask, dim=1)
+    loss1 = torch.mean(loss1)
+    assert torch.eq(loss['loss_cls'], loss1)
+
+    loss_ind_mask = HVULoss(
+        categories=categories,
+        category_nums=category_nums,
+        category_loss_weights=category_loss_weights,
+        loss_type='individual',
+        with_mask=True)
+    loss = loss_ind_mask(pred, gt, mask, category_mask)
+    action_loss = F.binary_cross_entropy_with_logits(pred[:1, :2], gt[:1, :2])
+    scene_loss = F.binary_cross_entropy_with_logits(pred[1:, 2:], gt[1:, 2:])
+    loss1 = (action_loss + scene_loss) / 2
+    assert torch.eq(loss['loss_cls'], loss1)
+
+    loss_ind_nomask_sum = HVULoss(
+        categories=categories,
+        category_nums=category_nums,
+        category_loss_weights=category_loss_weights,
+        loss_type='individual',
+        with_mask=False,
+        reduction='sum')
+    loss = loss_ind_nomask_sum(pred, gt, mask, category_mask)
+    action_loss = F.binary_cross_entropy_with_logits(
+        pred[:, :2], gt[:, :2], reduction='none')
+    action_loss = torch.sum(action_loss, dim=1)
+    action_loss = torch.mean(action_loss)
+
+    scene_loss = F.binary_cross_entropy_with_logits(
+        pred[:, 2:], gt[:, 2:], reduction='none')
+    scene_loss = torch.sum(scene_loss, dim=1)
+    scene_loss = torch.mean(scene_loss)
+
+    loss1 = (action_loss + scene_loss) / 2
+    assert torch.eq(loss['loss_cls'], loss1)
+
+
+def test_cross_entropy_loss():
+    cls_scores = torch.rand((3, 4))
+    hard_gt_labels = torch.LongTensor([0, 1, 2]).squeeze()
+    soft_gt_labels = torch.FloatTensor([[1, 0, 0, 0], [0, 1, 0, 0],
+                                        [0, 0, 1, 0]]).squeeze()
+
+    # hard label without weight
+    cross_entropy_loss = CrossEntropyLoss()
+    output_loss = cross_entropy_loss(cls_scores, hard_gt_labels)
+    assert torch.equal(output_loss, F.cross_entropy(cls_scores,
+                                                    hard_gt_labels))
+
+    # hard label with class weight
+    weight = torch.rand(4)
+    class_weight = weight.numpy().tolist()
+    cross_entropy_loss = CrossEntropyLoss(class_weight=class_weight)
+    output_loss = cross_entropy_loss(cls_scores, hard_gt_labels)
+    assert torch.equal(
+        output_loss,
+        F.cross_entropy(cls_scores, hard_gt_labels, weight=weight))
+
+    # soft label without class weight
+    cross_entropy_loss = CrossEntropyLoss()
+    output_loss = cross_entropy_loss(cls_scores, soft_gt_labels)
+    assert_almost_equal(
+        output_loss.numpy(),
+        F.cross_entropy(cls_scores, hard_gt_labels).numpy(),
+        decimal=4)
+
+    # soft label with class weight
+    cross_entropy_loss = CrossEntropyLoss(class_weight=class_weight)
+    output_loss = cross_entropy_loss(cls_scores, soft_gt_labels)
+    assert_almost_equal(
+        output_loss.numpy(),
+        F.cross_entropy(cls_scores, hard_gt_labels, weight=weight).numpy(),
+        decimal=4)
+
+
+def test_bce_loss_with_logits():
+    cls_scores = torch.rand((3, 4))
+    gt_labels = torch.rand((3, 4))
+
+    bce_loss_with_logits = BCELossWithLogits()
+    output_loss = bce_loss_with_logits(cls_scores, gt_labels)
+    assert torch.equal(
+        output_loss, F.binary_cross_entropy_with_logits(cls_scores, gt_labels))
+
+    weight = torch.rand(4)
+    class_weight = weight.numpy().tolist()
+    bce_loss_with_logits = BCELossWithLogits(class_weight=class_weight)
+    output_loss = bce_loss_with_logits(cls_scores, gt_labels)
+    assert torch.equal(
+        output_loss,
+        F.binary_cross_entropy_with_logits(
+            cls_scores, gt_labels, weight=weight))
+
+
+def test_nll_loss():
+    cls_scores = torch.randn(3, 3)
+    gt_labels = torch.tensor([0, 2, 1]).squeeze()
+
+    sm = nn.Softmax(dim=0)
+    nll_loss = NLLLoss()
+    cls_score_log = torch.log(sm(cls_scores))
+    output_loss = nll_loss(cls_score_log, gt_labels)
+    assert torch.equal(output_loss, F.nll_loss(cls_score_log, gt_labels))
+
+
+def test_binary_logistic_loss():
+    binary_logistic_regression_loss = BinaryLogisticRegressionLoss()
+    reg_score = torch.tensor([0., 1.])
+    label = torch.tensor([0., 1.])
+    output_loss = binary_logistic_regression_loss(reg_score, label, 0.5)
+    assert_array_almost_equal(output_loss.numpy(), np.array([0.]), decimal=4)
+
+    reg_score = torch.tensor([0.3, 0.9])
+    label = torch.tensor([0., 1.])
+    output_loss = binary_logistic_regression_loss(reg_score, label, 0.5)
+    assert_array_almost_equal(
+        output_loss.numpy(), np.array([0.231]), decimal=4)
+
+
+def test_bmn_loss():
+    bmn_loss = BMNLoss()
+
+    # test tem_loss
+    pred_start = torch.tensor([0.9, 0.1])
+    pred_end = torch.tensor([0.1, 0.9])
+    gt_start = torch.tensor([1., 0.])
+    gt_end = torch.tensor([0., 1.])
+    output_tem_loss = bmn_loss.tem_loss(pred_start, pred_end, gt_start, gt_end)
+    binary_logistic_regression_loss = BinaryLogisticRegressionLoss()
+    assert_loss = (
+        binary_logistic_regression_loss(pred_start, gt_start) +
+        binary_logistic_regression_loss(pred_end, gt_end))
+    assert_array_almost_equal(
+        output_tem_loss.numpy(), assert_loss.numpy(), decimal=4)
+
+    # test pem_reg_loss
+    seed = 1
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    pred_bm_reg = torch.tensor([[0.1, 0.99], [0.5, 0.4]])
+    gt_iou_map = torch.tensor([[0, 1.], [0, 1.]])
+    mask = torch.tensor([[0.1, 0.4], [0.4, 0.1]])
+    output_pem_reg_loss = bmn_loss.pem_reg_loss(pred_bm_reg, gt_iou_map, mask)
+    assert_array_almost_equal(
+        output_pem_reg_loss.numpy(), np.array([0.2140]), decimal=4)
+
+    # test pem_cls_loss
+    pred_bm_cls = torch.tensor([[0.1, 0.99], [0.95, 0.2]])
+    gt_iou_map = torch.tensor([[0., 1.], [0., 1.]])
+    mask = torch.tensor([[0.1, 0.4], [0.4, 0.1]])
+    output_pem_cls_loss = bmn_loss.pem_cls_loss(pred_bm_cls, gt_iou_map, mask)
+    assert_array_almost_equal(
+        output_pem_cls_loss.numpy(), np.array([1.6137]), decimal=4)
+
+    # test bmn_loss
+    pred_bm = torch.tensor([[[[0.1, 0.99], [0.5, 0.4]],
+                             [[0.1, 0.99], [0.95, 0.2]]]])
+    pred_start = torch.tensor([[0.9, 0.1]])
+    pred_end = torch.tensor([[0.1, 0.9]])
+    gt_iou_map = torch.tensor([[[0., 2.5], [0., 10.]]])
+    gt_start = torch.tensor([[1., 0.]])
+    gt_end = torch.tensor([[0., 1.]])
+    mask = torch.tensor([[0.1, 0.4], [0.4, 0.1]])
+    output_loss = bmn_loss(pred_bm, pred_start, pred_end, gt_iou_map, gt_start,
+                           gt_end, mask)
+    assert_array_almost_equal(
+        output_loss[0].numpy(),
+        output_tem_loss + 10 * output_pem_reg_loss + output_pem_cls_loss)
+    assert_array_almost_equal(output_loss[1].numpy(), output_tem_loss)
+    assert_array_almost_equal(output_loss[2].numpy(), output_pem_reg_loss)
+    assert_array_almost_equal(output_loss[3].numpy(), output_pem_cls_loss)
+
+
+def test_ohem_hinge_loss():
+    # test normal case
+    pred = torch.tensor([[
+        0.5161, 0.5228, 0.7748, 0.0573, 0.1113, 0.8862, 0.1752, 0.9448, 0.0253,
+        0.1009, 0.4371, 0.2232, 0.0412, 0.3487, 0.3350, 0.9294, 0.7122, 0.3072,
+        0.2942, 0.7679
+    ]],
+                        requires_grad=True)
+    gt = torch.tensor([8])
+    num_video = 1
+    loss = OHEMHingeLoss.apply(pred, gt, 1, 1.0, num_video)
+    assert_array_almost_equal(
+        loss.detach().numpy(), np.array([0.0552]), decimal=4)
+    loss.backward(Variable(torch.ones([1])))
+    assert_array_almost_equal(
+        np.array(pred.grad),
+        np.array([[
+            0., 0., 0., 0., 0., 0., 0., -1., 0., 0., 0., 0., 0., 0., 0., 0.,
+            0., 0., 0., 0.
+        ]]),
+        decimal=4)
+
+    # test error case
+    with pytest.raises(ValueError):
+        gt = torch.tensor([8, 10])
+        loss = OHEMHingeLoss.apply(pred, gt, 1, 1.0, num_video)
+
+
+def test_ssn_loss():
+    ssn_loss = SSNLoss()
+
+    # test activity_loss
+    activity_score = torch.rand((8, 21))
+    labels = torch.LongTensor([8] * 8).squeeze()
+    activity_indexer = torch.tensor([0, 7])
+    output_activity_loss = ssn_loss.activity_loss(activity_score, labels,
+                                                  activity_indexer)
+    assert torch.equal(
+        output_activity_loss,
+        F.cross_entropy(activity_score[activity_indexer, :],
+                        labels[activity_indexer]))
+
+    # test completeness_loss
+    completeness_score = torch.rand((8, 20), requires_grad=True)
+    labels = torch.LongTensor([8] * 8).squeeze()
+    completeness_indexer = torch.tensor([0, 1, 2, 3, 4, 5, 6])
+    positive_per_video = 1
+    incomplete_per_video = 6
+    output_completeness_loss = ssn_loss.completeness_loss(
+        completeness_score, labels, completeness_indexer, positive_per_video,
+        incomplete_per_video)
+
+    pred = completeness_score[completeness_indexer, :]
+    gt = labels[completeness_indexer]
+    pred_dim = pred.size(1)
+    pred = pred.view(-1, positive_per_video + incomplete_per_video, pred_dim)
+    gt = gt.view(-1, positive_per_video + incomplete_per_video)
+    # yapf:disable
+    positive_pred = pred[:, :positive_per_video, :].contiguous().view(-1, pred_dim)  # noqa:E501
+    incomplete_pred = pred[:, positive_per_video:, :].contiguous().view(-1, pred_dim)  # noqa:E501
+    # yapf:enable
+    ohem_ratio = 0.17
+    positive_loss = OHEMHingeLoss.apply(
+        positive_pred, gt[:, :positive_per_video].contiguous().view(-1), 1,
+        1.0, positive_per_video)
+    incomplete_loss = OHEMHingeLoss.apply(
+        incomplete_pred, gt[:, positive_per_video:].contiguous().view(-1), -1,
+        ohem_ratio, incomplete_per_video)
+    num_positives = positive_pred.size(0)
+    num_incompletes = int(incomplete_pred.size(0) * ohem_ratio)
+    assert_loss = ((positive_loss + incomplete_loss) /
+                   float(num_positives + num_incompletes))
+    assert torch.equal(output_completeness_loss, assert_loss)
+
+    # test reg_loss
+    bbox_pred = torch.rand((8, 20, 2))
+    labels = torch.LongTensor([8] * 8).squeeze()
+    bbox_targets = torch.rand((8, 2))
+    regression_indexer = torch.tensor([0])
+    output_reg_loss = ssn_loss.classwise_regression_loss(
+        bbox_pred, labels, bbox_targets, regression_indexer)
+
+    pred = bbox_pred[regression_indexer, :, :]
+    gt = labels[regression_indexer]
+    reg_target = bbox_targets[regression_indexer, :]
+    class_idx = gt.data - 1
+    classwise_pred = pred[:, class_idx, :]
+    classwise_reg_pred = torch.cat((torch.diag(classwise_pred[:, :, 0]).view(
+        -1, 1), torch.diag(classwise_pred[:, :, 1]).view(-1, 1)),
+                                   dim=1)
+    assert torch.equal(
+        output_reg_loss,
+        F.smooth_l1_loss(classwise_reg_pred.view(-1), reg_target.view(-1)) * 2)
+
+    # test ssn_loss
+    proposal_type = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 2]])
+    train_cfg = ConfigDict(
+        dict(
+            ssn=dict(
+                sampler=dict(
+                    num_per_video=8,
+                    positive_ratio=1,
+                    background_ratio=1,
+                    incomplete_ratio=6,
+                    add_gt_as_proposals=True),
+                loss_weight=dict(comp_loss_weight=0.1, reg_loss_weight=0.1))))
+    output_loss = ssn_loss(activity_score, completeness_score, bbox_pred,
+                           proposal_type, labels, bbox_targets, train_cfg)
+    assert torch.equal(output_loss['loss_activity'], output_activity_loss)
+    assert torch.equal(output_loss['loss_completeness'],
+                       output_completeness_loss * 0.1)
+    assert torch.equal(output_loss['loss_reg'], output_reg_loss * 0.1)