--- a
+++ b/demo/demo_video_structuralize.py
@@ -0,0 +1,797 @@
+# Copyright (c) OpenMMLab. All rights reserved.
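+"""Video structuralization demo.
+
+This demo combines human detection, pose estimation, action recognition
+(RGB-based or skeleton-based) and spatio-temporal action detection
+(RGB-based or skeleton-based) on a single video, and visualizes all of
+the results on the output video.
+"""
+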
+import argparse
+import copy as cp
+import os
+import os.path as osp
+import shutil
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv import DictAction
+from mmcv.runner import load_checkpoint
+
+from mmaction.apis import inference_recognizer
+from mmaction.datasets.pipelines import Compose
+from mmaction.models import build_detector, build_model, build_recognizer
+from mmaction.utils import import_module_error_func
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+    from mmpose.apis import (init_pose_model, inference_top_down_pose_model,
+                             vis_pose_result)
+
+except (ImportError, ModuleNotFoundError):
+
+    @import_module_error_func('mmdet')
+    def inference_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmdet')
+    def init_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def init_pose_model(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def inference_top_down_pose_model(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def vis_pose_result(*args, **kwargs):
+        pass
+
+
+try:
+    import moviepy.editor as mpy
+except ImportError:
+    raise ImportError('Please install moviepy to enable output file writing')
+
+FONTFACE = cv2.FONT_HERSHEY_DUPLEX
+FONTSCALE = 0.5
+FONTCOLOR = (255, 255, 255)  # BGR, white
+MSGCOLOR = (128, 128, 128)  # BGR, gray
+THICKNESS = 1
+LINETYPE = 1
+
+
+def hex2color(h):
+    """Convert a 6-digit hex string to a tuple of 3 int values (RGB)."""
+    return (int(h[:2], 16), int(h[2:4], 16), int(h[4:], 16))
+
+
+PLATEBLUE = '03045e-023e8a-0077b6-0096c7-00b4d8-48cae4'
+PLATEBLUE = PLATEBLUE.split('-')
+PLATEBLUE = [hex2color(h) for h in PLATEBLUE]
+PLATEGREEN = '004b23-006400-007200-008000-38b000-70e000'
+PLATEGREEN = PLATEGREEN.split('-')
+PLATEGREEN = [hex2color(h) for h in PLATEGREEN]
+
+
+def visualize(frames,
+              annotations,
+              pose_results,
+              action_result,
+              pose_model,
+              plate=PLATEBLUE,
+              max_num=5):
+    """Visualize frames with predicted annotations.
+
+    Args:
+        frames (list[np.ndarray]): Frames for visualization, note that
+            len(frames) % len(annotations) should be 0.
+        annotations (list[list[tuple]]): The predicted spatio-temporal
+            detection results.
+        pose_results (list[list[tuple]]): The pose estimation results.
+        action_result (str): The predicted action recognition result.
+        pose_model (nn.Module): The constructed pose model.
+        plate (list[tuple]): The color plate used for visualization.
+            Default: PLATEBLUE.
+        max_num (int): Max number of labels to visualize for a person box.
+            Default: 5.
+
+    Returns:
+        list[np.ndarray]: Visualized frames.
+    """
+
+    assert max_num + 1 <= len(plate)
+    plate = [x[::-1] for x in plate]
+    frames_ = cp.deepcopy(frames)
+    nf, na = len(frames), len(annotations)
+    assert nf % na == 0
+    nfpa = len(frames) // len(annotations)
+    anno = None
+    h, w, _ = frames[0].shape
+    scale_ratio = np.array([w, h, w, h])
+
+    # add pose results
+    if pose_results:
+        for i in range(nf):
+            frames_[i] = vis_pose_result(pose_model, frames_[i],
+                                         pose_results[i])
+
+    for i in range(na):
+        anno = annotations[i]
+        if anno is None:
+            continue
+        for j in range(nfpa):
+            ind = i * nfpa + j
+            frame = frames_[ind]
+
+            # add action result for whole video
+            cv2.putText(frame, action_result, (10, 30), FONTFACE, FONTSCALE,
+                        FONTCOLOR, THICKNESS, LINETYPE)
+
+            # add spatio-temporal action detection results
+            for ann in anno:
+                box = ann[0]
+                label = ann[1]
+                if not len(label):
+                    continue
+                score = ann[2]
+                box = (box * scale_ratio).astype(np.int64)
+                st, ed = tuple(box[:2]), tuple(box[2:])
+                if not pose_results:
+                    cv2.rectangle(frame, st, ed, plate[0], 2)
+
+                for k, lb in enumerate(label):
+                    if k >= max_num:
+                        break
+                    text = abbrev(lb)
+                    text = ': '.join([text, str(score[k])])
+                    location = (0 + st[0], 18 + k * 18 + st[1])
+                    textsize = cv2.getTextSize(text, FONTFACE, FONTSCALE,
+                                               THICKNESS)[0]
+                    textwidth = textsize[0]
+                    diag0 = (location[0] + textwidth, location[1] - 14)
+                    diag1 = (location[0], location[1] + 2)
+                    cv2.rectangle(frame, diag0, diag1, plate[k + 1], -1)
+                    cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
+                                FONTCOLOR, THICKNESS, LINETYPE)
+
+    return frames_
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='MMAction2 demo')
+    parser.add_argument(
+        '--rgb-stdet-config',
+        default=('configs/detection/ava/'
+                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py'),
+        help='rgb-based spatio-temporal detection config file path')
+    parser.add_argument(
+        '--rgb-stdet-checkpoint',
+        default=('https://download.openmmlab.com/mmaction/detection/ava/'
+                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/'
+                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb'
+                 '_20201217-16378594.pth'),
+        help='rgb-based spatio-temporal detection checkpoint file/url')
+    parser.add_argument(
+        '--skeleton-stdet-checkpoint',
+        default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/'
+                 'posec3d_ava.pth'),
+        help='skeleton-based spatio-temporal detection checkpoint file/url')
+    parser.add_argument(
+        '--det-config',
+        default='demo/faster_rcnn_r50_fpn_2x_coco.py',
+        help='human detection config file path (from mmdet)')
+    parser.add_argument(
+        '--det-checkpoint',
+        default=('http://download.openmmlab.com/mmdetection/v2.0/'
+                 'faster_rcnn/faster_rcnn_r50_fpn_2x_coco/'
+                 'faster_rcnn_r50_fpn_2x_coco_'
+                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
+        help='human detection checkpoint file/url')
+    parser.add_argument(
+        '--pose-config',
+        default='demo/hrnet_w32_coco_256x192.py',
+        help='human pose estimation config file path (from mmpose)')
+    parser.add_argument(
+        '--pose-checkpoint',
+        default=('https://download.openmmlab.com/mmpose/top_down/hrnet/'
+                 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'),
+        help='human pose estimation checkpoint file/url')
+    parser.add_argument(
+        '--skeleton-config',
+        default='configs/skeleton/posec3d/'
+        'slowonly_r50_u48_240e_ntu120_xsub_keypoint.py',
+        help='skeleton-based action recognition config file path')
+    parser.add_argument(
+        '--skeleton-checkpoint',
+        default='https://download.openmmlab.com/mmaction/skeleton/posec3d/'
+        'posec3d_k400.pth',
+        help='skeleton-based action recognition checkpoint file/url')
+    parser.add_argument(
+        '--rgb-config',
+        default='configs/recognition/tsn/'
+        'tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py',
+        help='rgb-based action recognition config file path')
+    parser.add_argument(
+        '--rgb-checkpoint',
+        default='https://download.openmmlab.com/mmaction/recognition/'
+        'tsn/tsn_r50_1x1x3_100e_kinetics400_rgb/'
+        'tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth',
+        help='rgb-based action recognition checkpoint file/url')
+    parser.add_argument(
+        '--use-skeleton-stdet',
+        action='store_true',
+        help='use skeleton-based spatio-temporal detection method')
+    parser.add_argument(
+        '--use-skeleton-recog',
+        action='store_true',
+        help='use skeleton-based action recognition method')
+    parser.add_argument(
+        '--det-score-thr',
+        type=float,
+        default=0.9,
+        help='the threshold of human detection score')
+    parser.add_argument(
+        '--action-score-thr',
+        type=float,
+        default=0.4,
+        help='the threshold of action prediction score')
+    parser.add_argument(
+        '--video',
+        default='demo/test_video_structuralize.mp4',
+        help='video file/url')
+    parser.add_argument(
+        '--label-map-stdet',
+        default='tools/data/ava/label_map.txt',
+        help='label map file for spatio-temporal action detection')
+    parser.add_argument(
+        '--label-map',
+        default='tools/data/kinetics/label_map_k400.txt',
+        help='label map file for action recognition')
+    parser.add_argument(
+        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
+    parser.add_argument(
+        '--out-filename',
+        default='demo/test_stdet_recognition_output.mp4',
+        help='output filename')
+    parser.add_argument(
+        '--predict-stepsize',
+        default=8,
+        type=int,
+        help='make a spatio-temporal detection prediction every n frames')
+    parser.add_argument(
+        '--output-stepsize',
+        default=1,
+        type=int,
+        help=('show one frame per n frames in the demo; note that '
+              'predict_stepsize % output_stepsize == 0 is required'))
+    parser.add_argument(
+        '--output-fps',
+        default=24,
+        type=int,
+        help='the fps of demo video output')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='override some settings in the used config; the key-value pairs '
+        'in xxx=yyy format will be merged into the config file. For example, '
+        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+    args = parser.parse_args()
+    return args
+
+
+def frame_extraction(video_path):
+    """Extract frames from the given video and save them as images.
+
+    Args:
+        video_path (str): The path of the video file.
+
+    Returns:
+        tuple[list[str], list[np.ndarray]]: The frame paths and frames.
+    """
+    # Load the video, extract frames into ./tmp/video_name
+    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
+    os.makedirs(target_dir, exist_ok=True)
+    # Should be able to handle videos up to several hours
+    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
+    vid = cv2.VideoCapture(video_path)
+    frames = []
+    frame_paths = []
+    flag, frame = vid.read()
+    cnt = 0
+    while flag:
+        frames.append(frame)
+        frame_path = frame_tmpl.format(cnt + 1)
+        frame_paths.append(frame_path)
+        cv2.imwrite(frame_path, frame)
+        cnt += 1
+        flag, frame = vid.read()
+    return frame_paths, frames
+
+
+def detection_inference(args, frame_paths):
+    """Detect human boxes given frame paths.
+
+    Args:
+        args (argparse.Namespace): The arguments.
+        frame_paths (list[str]): The paths of frames to do detection inference.
+
+    Returns:
+        list[np.ndarray]: The human detection results.
+    """
+    model = init_detector(args.det_config, args.det_checkpoint, args.device)
+    assert model.CLASSES[0] == 'person', ('We require you to use a detector '
+                                          'trained on COCO')
+    results = []
+    print('Performing Human Detection for each frame')
+    prog_bar = mmcv.ProgressBar(len(frame_paths))
+    for frame_path in frame_paths:
+        result = inference_detector(model, frame_path)
+        # We only keep human detections with score larger than det_score_thr
+        result = result[0][result[0][:, 4] >= args.det_score_thr]
+        results.append(result)
+        prog_bar.update()
+
+    return results
+
+
+def pose_inference(args, frame_paths, det_results):
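+    """Perform human pose estimation given frame paths and detection results.
+
+    Args:
+        args (argparse.Namespace): The arguments.
+        frame_paths (list[str]): The paths of frames.
+        det_results (list[np.ndarray]): The human detection results.
+
+    Returns:
+        list[list[dict]]: The pose estimation results of each frame.
+    """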
+    model = init_pose_model(args.pose_config, args.pose_checkpoint,
+                            args.device)
+    ret = []
+    print('Performing Human Pose Estimation for each frame')
+    prog_bar = mmcv.ProgressBar(len(frame_paths))
+    for f, d in zip(frame_paths, det_results):
+        # Align input format
+        d = [dict(bbox=x) for x in list(d)]
+
+        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
+        ret.append(pose)
+        prog_bar.update()
+    return ret
+
+
+def load_label_map(file_path):
+    """Load Label Map.
+
+    Args:
+        file_path (str): The file path of label map.
+
+    Returns:
+        dict: The label map (int -> label name).
+    """
+    lines = open(file_path).readlines()
+    lines = [x.strip().split(': ') for x in lines]
+    return {int(x[0]): x[1] for x in lines}
+
+
+def abbrev(name):
+    """Get the abbreviation of label name:
+
+    'take (an object) from (a person)' -> 'take ... from ...'
+    """
+    while name.find('(') != -1:
+        st, ed = name.find('('), name.find(')')
+        name = name[:st] + '...' + name[ed + 1:]
+    return name
+
+
+def pack_result(human_detection, result, img_h, img_w):
+    """Pack the spatio-temporal detection results for visualization.
+
+    Args:
+        human_detection (np.ndarray): Human detection result.
+        result (list): The predicted labels of each human proposal.
+        img_h (int): The image height.
+        img_w (int): The image width.
+
+    Returns:
+        list[tuple] | None: Tuples of human proposal, label names and label
+            scores, or None if ``result`` is None.
+    """
+    human_detection[:, 0::2] /= img_w
+    human_detection[:, 1::2] /= img_h
+    results = []
+    if result is None:
+        return None
+    for prop, res in zip(human_detection, result):
+        res.sort(key=lambda x: -x[1])
+        results.append(
+            (prop.data.cpu().numpy(), [x[0] for x in res], [x[1]
+                                                            for x in res]))
+    return results
+
+
+def expand_bbox(bbox, h, w, ratio=1.25):
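+    """Expand an xyxy bbox into a square box whose side length is
+    ``ratio * max(width, height)``, clipped to the image size."""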
+    x1, y1, x2, y2 = bbox
+    center_x = (x1 + x2) // 2
+    center_y = (y1 + y2) // 2
+    width = x2 - x1
+    height = y2 - y1
+
+    square_l = max(width, height)
+    new_width = new_height = square_l * ratio
+
+    new_x1 = max(0, int(center_x - new_width / 2))
+    new_x2 = min(int(center_x + new_width / 2), w)
+    new_y1 = max(0, int(center_y - new_height / 2))
+    new_y2 = min(int(center_y + new_height / 2), h)
+    return (new_x1, new_y1, new_x2, new_y2)
+
+
+def cal_iou(box1, box2):
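+    """Calculate the IoU (intersection over union) of two xyxy boxes."""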
+    xmin1, ymin1, xmax1, ymax1 = box1
+    xmin2, ymin2, xmax2, ymax2 = box2
+
+    s1 = (xmax1 - xmin1) * (ymax1 - ymin1)
+    s2 = (xmax2 - xmin2) * (ymax2 - ymin2)
+
+    xmin = max(xmin1, xmin2)
+    ymin = max(ymin1, ymin2)
+    xmax = min(xmax1, xmax2)
+    ymax = min(ymax1, ymax2)
+
+    w = max(0, xmax - xmin)
+    h = max(0, ymax - ymin)
+    intersect = w * h
+    union = s1 + s2 - intersect
+    iou = intersect / union
+
+    return iou
+
+
+def skeleton_based_action_recognition(args, pose_results, num_frame, h, w):
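+    """Perform skeleton-based action recognition for the whole video.
+
+    The per-frame pose results are packed into a fake annotation and fed
+    to the skeleton-based recognizer; the label with the highest score is
+    returned.
+    """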
+    fake_anno = dict(
+        frame_dir='',
+        label=-1,
+        img_shape=(h, w),
+        origin_shape=(h, w),
+        start_index=0,
+        modality='Pose',
+        total_frames=num_frame)
+    num_person = max([len(x) for x in pose_results])
+
+    num_keypoint = 17
+    keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
+                        dtype=np.float16)
+    keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
+                              dtype=np.float16)
+    for i, poses in enumerate(pose_results):
+        for j, pose in enumerate(poses):
+            pose = pose['keypoints']
+            keypoint[j, i] = pose[:, :2]
+            keypoint_score[j, i] = pose[:, 2]
+
+    fake_anno['keypoint'] = keypoint
+    fake_anno['keypoint_score'] = keypoint_score
+
+    label_map = [x.strip() for x in open(args.label_map).readlines()]
+    num_class = len(label_map)
+
+    skeleton_config = mmcv.Config.fromfile(args.skeleton_config)
+    skeleton_config.model.cls_head.num_classes = num_class  # for K400 dataset
+    skeleton_pipeline = Compose(skeleton_config.test_pipeline)
+    skeleton_imgs = skeleton_pipeline(fake_anno)['imgs'][None]
+    skeleton_imgs = skeleton_imgs.to(args.device)
+
+    # Build skeleton-based recognition model
+    skeleton_model = build_model(skeleton_config.model)
+    load_checkpoint(
+        skeleton_model, args.skeleton_checkpoint, map_location='cpu')
+    skeleton_model.to(args.device)
+    skeleton_model.eval()
+
+    with torch.no_grad():
+        output = skeleton_model(return_loss=False, imgs=skeleton_imgs)
+
+    action_idx = np.argmax(output)
+    skeleton_action_result = label_map[
+        action_idx]  # skeleton-based action result for the whole video
+    return skeleton_action_result
+
+
+def rgb_based_action_recognition(args):
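+    """Perform RGB-based action recognition for the whole video.
+
+    Returns:
+        str: The top-1 predicted action label.
+    """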
+    rgb_config = mmcv.Config.fromfile(args.rgb_config)
+    rgb_config.model.backbone.pretrained = None
+    rgb_model = build_recognizer(
+        rgb_config.model, test_cfg=rgb_config.get('test_cfg'))
+    load_checkpoint(rgb_model, args.rgb_checkpoint, map_location='cpu')
+    rgb_model.cfg = rgb_config
+    rgb_model.to(args.device)
+    rgb_model.eval()
+    action_results = inference_recognizer(rgb_model, args.video,
+                                          args.label_map)
+    rgb_action_result = action_results[0][0]
+    return rgb_action_result
+
+
+def skeleton_based_stdet(args, label_map, human_detections, pose_results,
+                         num_frame, clip_len, frame_interval, h, w):
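+    """Perform skeleton-based spatio-temporal action detection.
+
+    For each sampled timestamp and each human proposal, the best-matching
+    pose in every clip frame is selected by bbox IoU, packed into a fake
+    annotation and classified by the skeleton-based detector. Labels whose
+    scores exceed ``args.action_score_thr`` are kept.
+
+    Returns:
+        tuple: The sampled timestamps and the per-timestamp predictions.
+    """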
+    window_size = clip_len * frame_interval
+    assert clip_len % 2 == 0, 'We would like to have an even clip_len'
+    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
+                           args.predict_stepsize)
+
+    skeleton_config = mmcv.Config.fromfile(args.skeleton_config)
+    num_class = max(label_map.keys()) + 1  # for AVA dataset (81)
+    skeleton_config.model.cls_head.num_classes = num_class
+    skeleton_pipeline = Compose(skeleton_config.test_pipeline)
+    skeleton_stdet_model = build_model(skeleton_config.model)
+    load_checkpoint(
+        skeleton_stdet_model,
+        args.skeleton_stdet_checkpoint,
+        map_location='cpu')
+    skeleton_stdet_model.to(args.device)
+    skeleton_stdet_model.eval()
+
+    skeleton_predictions = []
+
+    print('Performing SpatioTemporal Action Detection for each clip')
+    prog_bar = mmcv.ProgressBar(len(timestamps))
+    for timestamp in timestamps:
+        proposal = human_detections[timestamp - 1]
+        if proposal.shape[0] == 0:  # no people detected
+            skeleton_predictions.append(None)
+            continue
+
+        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
+        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
+        frame_inds = list(frame_inds - 1)
+        num_frame = len(frame_inds)  # 30
+
+        pose_result = [pose_results[ind] for ind in frame_inds]
+
+        skeleton_prediction = []
+        for i in range(proposal.shape[0]):  # num_person
+            skeleton_prediction.append([])
+
+            fake_anno = dict(
+                frame_dir='',
+                label=-1,
+                img_shape=(h, w),
+                origin_shape=(h, w),
+                start_index=0,
+                modality='Pose',
+                total_frames=num_frame)
+            num_person = 1
+
+            num_keypoint = 17
+            keypoint = np.zeros(
+                (num_person, num_frame, num_keypoint, 2))  # M T V 2
+            keypoint_score = np.zeros(
+                (num_person, num_frame, num_keypoint))  # M T V
+
+            # pose matching
+            person_bbox = proposal[i][:4]
+            area = expand_bbox(person_bbox, h, w)
+
+            for j, poses in enumerate(pose_result):  # num_frame
+                max_iou = float('-inf')
+                index = -1
+                if len(poses) == 0:
+                    continue
+                for k, per_pose in enumerate(poses):
+                    iou = cal_iou(per_pose['bbox'][:4], area)
+                    if max_iou < iou:
+                        index = k
+                        max_iou = iou
+                keypoint[0, j] = poses[index]['keypoints'][:, :2]
+                keypoint_score[0, j] = poses[index]['keypoints'][:, 2]
+
+            fake_anno['keypoint'] = keypoint
+            fake_anno['keypoint_score'] = keypoint_score
+
+            skeleton_imgs = skeleton_pipeline(fake_anno)['imgs'][None]
+            skeleton_imgs = skeleton_imgs.to(args.device)
+
+            with torch.no_grad():
+                output = skeleton_stdet_model(
+                    return_loss=False, imgs=skeleton_imgs)
+                output = output[0]
+                for k in range(len(output)):  # 81
+                    if k not in label_map:
+                        continue
+                    if output[k] > args.action_score_thr:
+                        skeleton_prediction[i].append(
+                            (label_map[k], output[k]))
+
+        skeleton_predictions.append(skeleton_prediction)
+        prog_bar.update()
+
+    return timestamps, skeleton_predictions
+
+
+def rgb_based_stdet(args, frames, label_map, human_detections, w, h, new_w,
+                    new_h, w_ratio, h_ratio):
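+    """Perform RGB-based spatio-temporal action detection.
+
+    Frames around each sampled timestamp are normalized, stacked and fed
+    to the RGB-based detector together with the human proposals. Labels
+    whose scores exceed ``args.action_score_thr`` are kept.
+
+    Returns:
+        tuple: The sampled timestamps and the per-timestamp predictions.
+    """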
+
+    rgb_stdet_config = mmcv.Config.fromfile(args.rgb_stdet_config)
+    rgb_stdet_config.merge_from_dict(args.cfg_options)
+
+    val_pipeline = rgb_stdet_config.data.val.pipeline
+    sampler = [x for x in val_pipeline if x['type'] == 'SampleAVAFrames'][0]
+    clip_len, frame_interval = sampler['clip_len'], sampler['frame_interval']
+    assert clip_len % 2 == 0, 'We would like to have an even clip_len'
+
+    window_size = clip_len * frame_interval
+    num_frame = len(frames)
+    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
+                           args.predict_stepsize)
+
+    # Get img_norm_cfg
+    img_norm_cfg = rgb_stdet_config['img_norm_cfg']
+    if 'to_rgb' not in img_norm_cfg and 'to_bgr' in img_norm_cfg:
+        to_bgr = img_norm_cfg.pop('to_bgr')
+        img_norm_cfg['to_rgb'] = to_bgr
+    img_norm_cfg['mean'] = np.array(img_norm_cfg['mean'])
+    img_norm_cfg['std'] = np.array(img_norm_cfg['std'])
+
+    # Build STDET model
+    try:
+        # In our spatiotemporal detection demo, different actions should have
+        # the same number of bboxes.
+        rgb_stdet_config['model']['test_cfg']['rcnn']['action_thr'] = .0
+    except KeyError:
+        pass
+
+    rgb_stdet_config.model.backbone.pretrained = None
+    rgb_stdet_model = build_detector(
+        rgb_stdet_config.model, test_cfg=rgb_stdet_config.get('test_cfg'))
+
+    load_checkpoint(
+        rgb_stdet_model, args.rgb_stdet_checkpoint, map_location='cpu')
+    rgb_stdet_model.to(args.device)
+    rgb_stdet_model.eval()
+
+    predictions = []
+
+    print('Performing SpatioTemporal Action Detection for each clip')
+    prog_bar = mmcv.ProgressBar(len(timestamps))
+    for timestamp in timestamps:
+        proposal = human_detections[timestamp - 1]
+
+        if proposal.shape[0] == 0:
+            predictions.append(None)
+            continue
+
+        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
+        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
+        frame_inds = list(frame_inds - 1)
+
+        imgs = [frames[ind].astype(np.float32) for ind in frame_inds]
+        _ = [mmcv.imnormalize_(img, **img_norm_cfg) for img in imgs]
+        # THWC -> CTHW -> 1CTHW
+        input_array = np.stack(imgs).transpose((3, 0, 1, 2))[np.newaxis]
+        input_tensor = torch.from_numpy(input_array).to(args.device)
+
+        with torch.no_grad():
+            result = rgb_stdet_model(
+                return_loss=False,
+                img=[input_tensor],
+                img_metas=[[dict(img_shape=(new_h, new_w))]],
+                proposals=[[proposal]])
+            result = result[0]
+            prediction = []
+            # N proposals
+            for i in range(proposal.shape[0]):
+                prediction.append([])
+
+            # Perform action score thr
+            for i in range(len(result)):  # 80
+                if i + 1 not in label_map:
+                    continue
+                for j in range(proposal.shape[0]):
+                    if result[i][j, 4] > args.action_score_thr:
+                        prediction[j].append((label_map[i + 1], result[i][j,
+                                                                          4]))
+            predictions.append(prediction)
+        prog_bar.update()
+
+    return timestamps, predictions
+
+
+def main():
+    args = parse_args()
+
+    frame_paths, original_frames = frame_extraction(args.video)
+    num_frame = len(frame_paths)
+    h, w, _ = original_frames[0].shape
+
+    # Get Human detection results and pose results
+    human_detections = detection_inference(args, frame_paths)
+    pose_results = None
+    if args.use_skeleton_recog or args.use_skeleton_stdet:
+        pose_results = pose_inference(args, frame_paths, human_detections)
+
+    # resize frames to shortside 256
+    new_w, new_h = mmcv.rescale_size((w, h), (256, np.Inf))
+    frames = [mmcv.imresize(img, (new_w, new_h)) for img in original_frames]
+    w_ratio, h_ratio = new_w / w, new_h / h
+
+    # Load spatio-temporal detection label_map
+    stdet_label_map = load_label_map(args.label_map_stdet)
+    rgb_stdet_config = mmcv.Config.fromfile(args.rgb_stdet_config)
+    rgb_stdet_config.merge_from_dict(args.cfg_options)
+    try:
+        if rgb_stdet_config['data']['train']['custom_classes'] is not None:
+            stdet_label_map = {
+                id + 1: stdet_label_map[cls]
+                for id, cls in enumerate(rgb_stdet_config['data']['train']
+                                         ['custom_classes'])
+            }
+    except KeyError:
+        pass
+
+    action_result = None
+    if args.use_skeleton_recog:
+        print('Use skeleton-based recognition')
+        action_result = skeleton_based_action_recognition(
+            args, pose_results, num_frame, h, w)
+    else:
+        print('Use rgb-based recognition')
+        action_result = rgb_based_action_recognition(args)
+
+    stdet_preds = None
+    if args.use_skeleton_stdet:
+        print('Use skeleton-based SpatioTemporal Action Detection')
+        clip_len, frame_interval = 30, 1
+        timestamps, stdet_preds = skeleton_based_stdet(args, stdet_label_map,
+                                                       human_detections,
+                                                       pose_results, num_frame,
+                                                       clip_len,
+                                                       frame_interval, h, w)
+        for i in range(len(human_detections)):
+            det = human_detections[i]
+            det[:, 0:4:2] *= w_ratio
+            det[:, 1:4:2] *= h_ratio
+            human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
+
+    else:
+        print('Use rgb-based SpatioTemporal Action Detection')
+        for i in range(len(human_detections)):
+            det = human_detections[i]
+            det[:, 0:4:2] *= w_ratio
+            det[:, 1:4:2] *= h_ratio
+            human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
+        timestamps, stdet_preds = rgb_based_stdet(args, frames,
+                                                  stdet_label_map,
+                                                  human_detections, w, h,
+                                                  new_w, new_h, w_ratio,
+                                                  h_ratio)
+
+    stdet_results = []
+    for timestamp, prediction in zip(timestamps, stdet_preds):
+        human_detection = human_detections[timestamp - 1]
+        stdet_results.append(
+            pack_result(human_detection, prediction, new_h, new_w))
+
+    def dense_timestamps(timestamps, n):
+        """Make the timestamps ``n`` times denser."""
+        old_frame_interval = (timestamps[1] - timestamps[0])
+        start = timestamps[0] - old_frame_interval / n * (n - 1) / 2
+        new_frame_inds = np.arange(
+            len(timestamps) * n) * old_frame_interval / n + start
+        return new_frame_inds.astype(np.int64)
+
+    dense_n = int(args.predict_stepsize / args.output_stepsize)
+    output_timestamps = dense_timestamps(timestamps, dense_n)
+    frames = [
+        cv2.imread(frame_paths[timestamp - 1])
+        for timestamp in output_timestamps
+    ]
+
+    print('Performing visualization')
+    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
+                                 args.device)
+
+    if args.use_skeleton_recog or args.use_skeleton_stdet:
+        pose_results = [
+            pose_results[timestamp - 1] for timestamp in output_timestamps
+        ]
+
+    vis_frames = visualize(frames, stdet_results, pose_results, action_result,
+                           pose_model)
+    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames],
+                                fps=args.output_fps)
+    vid.write_videofile(args.out_filename)
+
+    tmp_frame_dir = osp.dirname(frame_paths[0])
+    shutil.rmtree(tmp_frame_dir)
+
+
+if __name__ == '__main__':
+    main()