--- /dev/null
+++ b/demo/demo_skeleton.py
@@ -0,0 +1,266 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import os.path as osp
+import shutil
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv import DictAction
+
+from mmaction.apis import inference_recognizer, init_recognizer
+from mmaction.utils import import_module_error_func
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+    from mmpose.apis import (init_pose_model, inference_top_down_pose_model,
+                             vis_pose_result)
+except (ImportError, ModuleNotFoundError):
+
+    @import_module_error_func('mmdet')
+    def inference_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmdet')
+    def init_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def init_pose_model(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def inference_top_down_pose_model(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def vis_pose_result(*args, **kwargs):
+        pass
+
+
+try:
+    import moviepy.editor as mpy
+except ImportError:
+    raise ImportError('Please install moviepy to enable output file')
+
+FONTFACE = cv2.FONT_HERSHEY_DUPLEX
+FONTSCALE = 0.75
+FONTCOLOR = (255, 255, 255)  # BGR, white
+THICKNESS = 1
+LINETYPE = 1
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='MMAction2 demo')
+    parser.add_argument('video', help='video file/url')
+    parser.add_argument('out_filename', help='output filename')
+    parser.add_argument(
+        '--config',
+        default=('configs/skeleton/posec3d/'
+                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint.py'),
+        help='skeleton model config file path')
+    parser.add_argument(
+        '--checkpoint',
+        default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/'
+                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint/'
+                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint-6736b03f.pth'),
+        help='skeleton model checkpoint file/url')
+    parser.add_argument(
+        '--det-config',
+        default='demo/faster_rcnn_r50_fpn_2x_coco.py',
+        help='human detection config file path (from mmdet)')
+    parser.add_argument(
+        '--det-checkpoint',
+        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
+                 'faster_rcnn_r50_fpn_2x_coco/'
+                 'faster_rcnn_r50_fpn_2x_coco_'
+                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
+        help='human detection checkpoint file/url')
+    parser.add_argument(
+        '--pose-config',
+        default='demo/hrnet_w32_coco_256x192.py',
+        help='human pose estimation config file path (from mmpose)')
+    parser.add_argument(
+        '--pose-checkpoint',
+        default=('https://download.openmmlab.com/mmpose/top_down/hrnet/'
+                 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'),
+        help='human pose estimation checkpoint file/url')
+    parser.add_argument(
+        '--det-score-thr',
+        type=float,
+        default=0.9,
+        help='the threshold of human detection score')
+    parser.add_argument(
+        '--label-map',
+        default='tools/data/skeleton/label_map_ntu120.txt',
+        help='label map file')
+    parser.add_argument(
+        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
+    parser.add_argument(
+        '--short-side',
+        type=int,
+        default=480,
+        help='specify the short-side length of the image')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. For example, '
+        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+    args = parser.parse_args()
+    return args
+
+
+def frame_extraction(video_path, short_side):
+    """Extract frames given video_path.
+
+    Args:
+        video_path (str): The video_path.
+    """
+    # Load the video, extract frames into ./tmp/video_name
+    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
+    os.makedirs(target_dir, exist_ok=True)
+    # Should be able to handle videos up to several hours
+    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
+    vid = cv2.VideoCapture(video_path)
+    frames = []
+    frame_paths = []
+    flag, frame = vid.read()
+    cnt = 0
+    new_h, new_w = None, None
+    while flag:
+        if new_h is None:
+            h, w, _ = frame.shape
+            new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf))
+
+        frame = mmcv.imresize(frame, (new_w, new_h))
+
+        frames.append(frame)
+        frame_path = frame_tmpl.format(cnt + 1)
+        frame_paths.append(frame_path)
+
+        cv2.imwrite(frame_path, frame)
+        cnt += 1
+        flag, frame = vid.read()
+
+    return frame_paths, frames
+
+
+def detection_inference(args, frame_paths):
+    """Detect human boxes given frame paths.
+
+    Args:
+        args (argparse.Namespace): The arguments.
+        frame_paths (list[str]): The paths of frames to do detection inference.
+
+    Returns:
+        list[np.ndarray]: The human detection results.
+    """
+    model = init_detector(args.det_config, args.det_checkpoint, args.device)
+    assert model.CLASSES[0] == 'person', ('We require you to use a detector '
+                                          'trained on COCO')
+    results = []
+    print('Performing Human Detection for each frame')
+    prog_bar = mmcv.ProgressBar(len(frame_paths))
+    for frame_path in frame_paths:
+        result = inference_detector(model, frame_path)
+        # We only keep human detections with score larger than det_score_thr
+        result = result[0][result[0][:, 4] >= args.det_score_thr]
+        results.append(result)
+        prog_bar.update()
+    return results
+
+
+def pose_inference(args, frame_paths, det_results):
+    model = init_pose_model(args.pose_config, args.pose_checkpoint,
+                            args.device)
+    ret = []
+    print('Performing Human Pose Estimation for each frame')
+    prog_bar = mmcv.ProgressBar(len(frame_paths))
+    for f, d in zip(frame_paths, det_results):
+        # Align input format
+        d = [dict(bbox=x) for x in list(d)]
+        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
+        ret.append(pose)
+        prog_bar.update()
+    return ret
+
+
+def main():
+    args = parse_args()
+
+    frame_paths, original_frames = frame_extraction(args.video,
+                                                    args.short_side)
+    num_frame = len(frame_paths)
+    h, w, _ = original_frames[0].shape
+
+    # Get clip_len, frame_interval and calculate center index of each clip
+    config = mmcv.Config.fromfile(args.config)
+    config.merge_from_dict(args.cfg_options)
+    for component in config.data.test.pipeline:
+        if component['type'] == 'PoseNormalize':
+            component['mean'] = (w // 2, h // 2, .5)
+            component['max_value'] = (w, h, 1.)
+
+    model = init_recognizer(config, args.checkpoint, args.device)
+
+    # Load label_map
+    label_map = [x.strip() for x in open(args.label_map).readlines()]
+
+    # Get Human detection results
+    det_results = detection_inference(args, frame_paths)
+    torch.cuda.empty_cache()
+
+    pose_results = pose_inference(args, frame_paths, det_results)
+    torch.cuda.empty_cache()
+
+    fake_anno = dict(
+        frame_dir='',
+        label=-1,
+        img_shape=(h, w),
+        original_shape=(h, w),
+        start_index=0,
+        modality='Pose',
+        total_frames=num_frame)
+    num_person = max([len(x) for x in pose_results])
+
+    num_keypoint = 17
+    keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
+                        dtype=np.float16)
+    keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
+                              dtype=np.float16)
+    for i, poses in enumerate(pose_results):
+        for j, pose in enumerate(poses):
+            pose = pose['keypoints']
+            keypoint[j, i] = pose[:, :2]
+            keypoint_score[j, i] = pose[:, 2]
+    fake_anno['keypoint'] = keypoint
+    fake_anno['keypoint_score'] = keypoint_score
+
+    results = inference_recognizer(model, fake_anno)
+
+    action_label = label_map[results[0][0]]
+
+    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
+                                 args.device)
+    vis_frames = [
+        vis_pose_result(pose_model, frame_paths[i], pose_results[i])
+        for i in range(num_frame)
+    ]
+    for frame in vis_frames:
+        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
+                    FONTCOLOR, THICKNESS, LINETYPE)
+
+    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
+    vid.write_videofile(args.out_filename, remove_temp=True)
+
+    tmp_frame_dir = osp.dirname(frame_paths[0])
+    shutil.rmtree(tmp_frame_dir)
+
+
+if __name__ == '__main__':
+    main()
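
For reference, a minimal sketch of how the new demo could be launched once this file lands; the input video and output filename below are hypothetical placeholders, while the option names and defaults come from parse_args in the diff above.

import subprocess

# Command line for the new skeleton-based action recognition demo.
cmd = [
    'python', 'demo/demo_skeleton.py',
    'demo/example_video.mp4',       # hypothetical input video (positional arg)
    'demo/skeleton_demo_out.mp4',   # hypothetical output filename (positional arg)
    '--device', 'cuda:0',           # matches the parse_args default
    '--short-side', '480',          # matches the parse_args default
]
# Run the demo from the repository root; raises CalledProcessError on failure.
subprocess.run(cmd, check=True)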