--- /dev/null
+++ b/demo/demo_spatiotemporal_det.py
@@ -0,0 +1,428 @@
+# Copyright (c) OpenMMLab. All rights reserved.
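+"""Spatio-temporal action detection demo.
+
+Detects humans in a video with an MMDetection detector, predicts their
+actions with an MMAction2 spatio-temporal detector and writes a video with
+the visualized results.
+
+Example invocation (the video path is illustrative):
+
+    python demo/demo_spatiotemporal_det.py --video demo/demo.mp4
+"""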
+import argparse
+import copy as cp
+import os
+import os.path as osp
+import shutil
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv import DictAction
+from mmcv.runner import load_checkpoint
+
+from mmaction.models import build_detector
+from mmaction.utils import import_module_error_func
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+except (ImportError, ModuleNotFoundError):
+
+    @import_module_error_func('mmdet')
+    def inference_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmdet')
+    def init_detector(*args, **kwargs):
+        pass
+
+
+try:
+    import moviepy.editor as mpy
+except ImportError:
+    raise ImportError('Please install moviepy to write the output video file')
+
+FONTFACE = cv2.FONT_HERSHEY_DUPLEX
+FONTSCALE = 0.5
+FONTCOLOR = (255, 255, 255)  # BGR, white
+MSGCOLOR = (128, 128, 128)  # BGR, gray
+THICKNESS = 1
+LINETYPE = 1
+
+
+def hex2color(h):
+    """Convert the 6-digit hex string to tuple of 3 int value (RGB)"""
+    return (int(h[:2], 16), int(h[2:4], 16), int(h[4:], 16))
+
+
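+# Color plates: dash-separated hex colors converted to RGB tuples. Index 0
+# colors the person box, indices 1..max_num color the label backgrounds.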
+plate_blue = '03045e-023e8a-0077b6-0096c7-00b4d8-48cae4'
+plate_blue = plate_blue.split('-')
+plate_blue = [hex2color(h) for h in plate_blue]
+plate_green = '004b23-006400-007200-008000-38b000-70e000'
+plate_green = plate_green.split('-')
+plate_green = [hex2color(h) for h in plate_green]
+
+
+def visualize(frames, annotations, plate=plate_blue, max_num=5):
+    """Visualize frames with predicted annotations.
+
+    Args:
+        frames (list[np.ndarray]): Frames for visualization. Note that
+            len(frames) % len(annotations) should be 0.
+        annotations (list[list[tuple]]): The predicted results.
+        plate (list[tuple]): The color plate used for visualization.
+            Default: plate_blue.
+        max_num (int): Max number of labels to visualize for a person box.
+            Default: 5.
+
+    Returns:
+        list[np.ndarray]: Visualized frames.
+    """
+
+    assert max_num + 1 <= len(plate)
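+    # hex2color gives RGB tuples; reverse them to BGR for OpenCV drawing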
+    plate = [x[::-1] for x in plate]
+    frames_ = cp.deepcopy(frames)
+    nf, na = len(frames), len(annotations)
+    assert nf % na == 0
+    nfpa = len(frames) // len(annotations)
+    anno = None
+    h, w, _ = frames[0].shape
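+    # boxes from pack_result are normalized; scale them back to pixels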
+    scale_ratio = np.array([w, h, w, h])
+    for i in range(na):
+        anno = annotations[i]
+        if anno is None:
+            continue
+        for j in range(nfpa):
+            ind = i * nfpa + j
+            frame = frames_[ind]
+            for ann in anno:
+                box = ann[0]
+                label = ann[1]
+                if not len(label):
+                    continue
+                score = ann[2]
+                box = (box * scale_ratio).astype(np.int64)
+                st, ed = tuple(box[:2]), tuple(box[2:])
+                cv2.rectangle(frame, st, ed, plate[0], 2)
+                for k, lb in enumerate(label):
+                    if k >= max_num:
+                        break
+                    text = abbrev(lb)
+                    text = ': '.join([text, str(score[k])])
+                    location = (0 + st[0], 18 + k * 18 + st[1])
+                    textsize = cv2.getTextSize(text, FONTFACE, FONTSCALE,
+                                               THICKNESS)[0]
+                    textwidth = textsize[0]
+                    diag0 = (location[0] + textwidth, location[1] - 14)
+                    diag1 = (location[0], location[1] + 2)
+                    cv2.rectangle(frame, diag0, diag1, plate[k + 1], -1)
+                    cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
+                                FONTCOLOR, THICKNESS, LINETYPE)
+
+    return frames_
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='MMAction2 demo')
+    parser.add_argument(
+        '--config',
+        default=('configs/detection/ava/'
+                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py'),
+        help='spatio temporal detection config file path')
+    parser.add_argument(
+        '--checkpoint',
+        default=('https://download.openmmlab.com/mmaction/detection/ava/'
+                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/'
+                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb'
+                 '_20201217-16378594.pth'),
+        help='spatio temporal detection checkpoint file/url')
+    parser.add_argument(
+        '--det-config',
+        default='demo/faster_rcnn_r50_fpn_2x_coco.py',
+        help='human detection config file path (from mmdet)')
+    parser.add_argument(
+        '--det-checkpoint',
+        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
+                 'faster_rcnn_r50_fpn_2x_coco/'
+                 'faster_rcnn_r50_fpn_2x_coco_'
+                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
+        help='human detection checkpoint file/url')
+    parser.add_argument(
+        '--det-score-thr',
+        type=float,
+        default=0.9,
+        help='the threshold of human detection score')
+    parser.add_argument(
+        '--action-score-thr',
+        type=float,
+        default=0.5,
+        help='the threshold of human action score')
+    parser.add_argument('--video', help='video file/url')
+    parser.add_argument(
+        '--label-map',
+        default='tools/data/ava/label_map.txt',
+        help='label map file')
+    parser.add_argument(
+        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
+    parser.add_argument(
+        '--out-filename',
+        default='demo/stdet_demo.mp4',
+        help='output filename')
+    parser.add_argument(
+        '--predict-stepsize',
+        default=8,
+        type=int,
+        help='give out a prediction per n frames')
+    parser.add_argument(
+        '--output-stepsize',
+        default=4,
+        type=int,
+        help=('show one frame per n frames in the demo; note that '
+              'predict_stepsize % output_stepsize == 0 should hold'))
+    parser.add_argument(
+        '--output-fps',
+        default=6,
+        type=int,
+        help='the fps of demo video output')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. For example, '
+        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+    args = parser.parse_args()
+    return args
+
+
+def frame_extraction(video_path):
+    """Extract frames given video_path.
+
+    Args:
+        video_path (str): The video_path.
+    """
+    # Load the video, extract frames into ./tmp/video_name
+    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
+    os.makedirs(target_dir, exist_ok=True)
+    # Should be able to handle videos up to several hours
+    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
+    vid = cv2.VideoCapture(video_path)
+    frames = []
+    frame_paths = []
+    flag, frame = vid.read()
+    cnt = 0
+    while flag:
+        frames.append(frame)
+        frame_path = frame_tmpl.format(cnt + 1)
+        frame_paths.append(frame_path)
+        cv2.imwrite(frame_path, frame)
+        cnt += 1
+        flag, frame = vid.read()
+    return frame_paths, frames
+
+
+def detection_inference(args, frame_paths):
+    """Detect human boxes given frame paths.
+
+    Args:
+        args (argparse.Namespace): The arguments.
+        frame_paths (list[str]): The paths of frames to do detection inference.
+
+    Returns:
+        list[np.ndarray]: The human detection results.
+    """
+    model = init_detector(args.det_config, args.det_checkpoint, args.device)
+    assert model.CLASSES[0] == 'person', ('We require you to use a detector '
+                                          'trained on COCO')
+    results = []
+    print('Performing Human Detection for each frame')
+    prog_bar = mmcv.ProgressBar(len(frame_paths))
+    for frame_path in frame_paths:
+        result = inference_detector(model, frame_path)
+        # We only keep human detections with score larger than det_score_thr
+        result = result[0][result[0][:, 4] >= args.det_score_thr]
+        results.append(result)
+        prog_bar.update()
+    return results
+
+
+def load_label_map(file_path):
+    """Load Label Map.
+
+    Args:
+        file_path (str): The file path of label map.
+
+    Returns:
+        dict: The label map (int -> label name).
+    """
+    with open(file_path) as f:
+        lines = f.readlines()
+    lines = [x.strip().split(': ') for x in lines]
+    return {int(x[0]): x[1] for x in lines}
+
+
+def abbrev(name):
+    """Get the abbreviation of label name:
+
+    'take (an object) from (a person)' -> 'take ... from ...'
+    """
+    while name.find('(') != -1:
+        st, ed = name.find('('), name.find(')')
+        name = name[:st] + '...' + name[ed + 1:]
+    return name
+
+
+def pack_result(human_detection, result, img_h, img_w):
+    """Short summary.
+
+    Args:
+        human_detection (np.ndarray): Human detection result.
+        result (type): The predicted label of each human proposal.
+        img_h (int): The image height.
+        img_w (int): The image width.
+
+    Returns:
+        tuple: Tuple of human proposal, label name and label score.
+    """
+    human_detection[:, 0::2] /= img_w
+    human_detection[:, 1::2] /= img_h
+    results = []
+    if result is None:
+        return None
+    for prop, res in zip(human_detection, result):
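+        # sort the predicted (label, score) pairs by descending score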
+        res.sort(key=lambda x: -x[1])
+        results.append(
+            (prop.data.cpu().numpy(), [x[0] for x in res], [x[1]
+                                                            for x in res]))
+    return results
+
+
+def main():
+    args = parse_args()
+
+    frame_paths, original_frames = frame_extraction(args.video)
+    num_frame = len(frame_paths)
+    h, w, _ = original_frames[0].shape
+
+    # resize frames so that the short side is 256
+    new_w, new_h = mmcv.rescale_size((w, h), (256, np.inf))
+    frames = [mmcv.imresize(img, (new_w, new_h)) for img in original_frames]
+    w_ratio, h_ratio = new_w / w, new_h / h
+
+    # Get clip_len, frame_interval and calculate center index of each clip
+    config = mmcv.Config.fromfile(args.config)
+    config.merge_from_dict(args.cfg_options)
+    val_pipeline = config.data.val.pipeline
+
+    sampler = [x for x in val_pipeline if x['type'] == 'SampleAVAFrames'][0]
+    clip_len, frame_interval = sampler['clip_len'], sampler['frame_interval']
+    window_size = clip_len * frame_interval
+    assert clip_len % 2 == 0, 'We would like to have an even clip_len'
+    # Note that timestamps are 1-based here
+    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
+                           args.predict_stepsize)
+
+    # Load label_map
+    label_map = load_label_map(args.label_map)
+    try:
+        if config['data']['train']['custom_classes'] is not None:
+            label_map = {
+                id + 1: label_map[cls]
+                for id, cls in enumerate(config['data']['train']
+                                         ['custom_classes'])
+            }
+    except KeyError:
+        pass
+
+    # Get Human detection results
+    center_frames = [frame_paths[ind - 1] for ind in timestamps]
+    human_detections = detection_inference(args, center_frames)
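+    # rescale the human boxes to the resized frames, drop the detection
+    # scores and move the boxes to the target device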
+    for i in range(len(human_detections)):
+        det = human_detections[i]
+        det[:, 0:4:2] *= w_ratio
+        det[:, 1:4:2] *= h_ratio
+        human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
+
+    # Get img_norm_cfg
+    img_norm_cfg = config['img_norm_cfg']
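+    # some configs use a 'to_bgr' key instead of 'to_rgb'; rename it so the
+    # dict can be passed to mmcv.imnormalize_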
+    if 'to_rgb' not in img_norm_cfg and 'to_bgr' in img_norm_cfg:
+        to_bgr = img_norm_cfg.pop('to_bgr')
+        img_norm_cfg['to_rgb'] = to_bgr
+    img_norm_cfg['mean'] = np.array(img_norm_cfg['mean'])
+    img_norm_cfg['std'] = np.array(img_norm_cfg['std'])
+
+    # Build STDET model
+    try:
+        # In our spatiotemporal detection demo, different actions should have
+        # the same number of bboxes.
+        config['model']['test_cfg']['rcnn']['action_thr'] = .0
+    except KeyError:
+        pass
+
+    config.model.backbone.pretrained = None
+    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+    model.to(args.device)
+    model.eval()
+
+    predictions = []
+
+    print('Performing SpatioTemporal Action Detection for each clip')
+    assert len(timestamps) == len(human_detections)
+    prog_bar = mmcv.ProgressBar(len(timestamps))
+    for timestamp, proposal in zip(timestamps, human_detections):
+        if proposal.shape[0] == 0:
+            predictions.append(None)
+            continue
+
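+        # gather clip_len frames around the timestamp, spaced by
+        # frame_interval, and convert the 1-based indices to 0-based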
+        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
+        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
+        frame_inds = list(frame_inds - 1)
+        imgs = [frames[ind].astype(np.float32) for ind in frame_inds]
+        _ = [mmcv.imnormalize_(img, **img_norm_cfg) for img in imgs]
+        # THWC -> CTHW -> 1CTHW
+        input_array = np.stack(imgs).transpose((3, 0, 1, 2))[np.newaxis]
+        input_tensor = torch.from_numpy(input_array).to(args.device)
+
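+        # run the spatio-temporal detector on the clip, using the detected
+        # human boxes as proposals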
+        with torch.no_grad():
+            result = model(
+                return_loss=False,
+                img=[input_tensor],
+                img_metas=[[dict(img_shape=(new_h, new_w))]],
+                proposals=[[proposal]])
+            result = result[0]
+            prediction = []
+            # N proposals
+            for i in range(proposal.shape[0]):
+                prediction.append([])
+            # Perform action score thr
+            for i in range(len(result)):
+                if i + 1 not in label_map:
+                    continue
+                for j in range(proposal.shape[0]):
+                    if result[i][j, 4] > args.action_score_thr:
+                        prediction[j].append((label_map[i + 1], result[i][j,
+                                                                          4]))
+            predictions.append(prediction)
+        prog_bar.update()
+
+    results = []
+    for human_detection, prediction in zip(human_detections, predictions):
+        results.append(pack_result(human_detection, prediction, new_h, new_w))
+
+    def dense_timestamps(timestamps, n):
+        """Make it nx frames."""
+        old_frame_interval = (timestamps[1] - timestamps[0])
+        start = timestamps[0] - old_frame_interval / n * (n - 1) / 2
+        new_frame_inds = np.arange(
+            len(timestamps) * n) * old_frame_interval / n + start
+        return new_frame_inds.astype(np.int64)
+
+    dense_n = int(args.predict_stepsize / args.output_stepsize)
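+    # re-read the original-resolution frames that appear in the output video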
+    frames = [
+        cv2.imread(frame_paths[i - 1])
+        for i in dense_timestamps(timestamps, dense_n)
+    ]
+    print('Performing visualization')
+    vis_frames = visualize(frames, results)
+    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames],
+                                fps=args.output_fps)
+    vid.write_videofile(args.out_filename)
+
+    tmp_frame_dir = osp.dirname(frame_paths[0])
+    shutil.rmtree(tmp_frame_dir)
+
+
+if __name__ == '__main__':
+    main()