Diff of /demo/demo_skeleton.py [000000] .. [6d389a]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil

import cv2
import mmcv
import numpy as np
import torch
from mmcv import DictAction

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.utils import import_module_error_func

try:
    from mmdet.apis import inference_detector, init_detector
    from mmpose.apis import (init_pose_model, inference_top_down_pose_model,
                             vis_pose_result)
except (ImportError, ModuleNotFoundError):

    @import_module_error_func('mmdet')
    def inference_detector(*args, **kwargs):
        pass

    @import_module_error_func('mmdet')
    def init_detector(*args, **kwargs):
        pass

    @import_module_error_func('mmpose')
    def init_pose_model(*args, **kwargs):
        pass

    @import_module_error_func('mmpose')
    def inference_top_down_pose_model(*args, **kwargs):
        pass

    @import_module_error_func('mmpose')
    def vis_pose_result(*args, **kwargs):
        pass


try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable output file')

FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255)  # BGR, white
THICKNESS = 1
LINETYPE = 1


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('video', help='video file/url')
    parser.add_argument('out_filename', help='output filename')
    parser.add_argument(
        '--config',
        default=('configs/skeleton/posec3d/'
                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint.py'),
        help='skeleton model config file path')
    parser.add_argument(
        '--checkpoint',
        default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/'
                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint/'
                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint-6736b03f.pth'),
        help='skeleton model checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='demo/faster_rcnn_r50_fpn_2x_coco.py',
        help='human detection config file path (from mmdet)')
    parser.add_argument(
        '--det-checkpoint',
        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
                 'faster_rcnn_r50_fpn_2x_coco/'
                 'faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--pose-config',
        default='demo/hrnet_w32_coco_256x192.py',
        help='human pose estimation config file path (from mmpose)')
    parser.add_argument(
        '--pose-checkpoint',
        default=('https://download.openmmlab.com/mmpose/top_down/hrnet/'
                 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'),
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--det-score-thr',
        type=float,
        default=0.9,
        help='the threshold of human detection score')
    parser.add_argument(
        '--label-map',
        default='tools/data/skeleton/label_map_ntu120.txt',
        help='label map file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    args = parser.parse_args()
    return args
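
# A possible invocation, for reference (a sketch; the sample clip path is an
# assumption). Every config/checkpoint option above has a usable default, so
# only the two positional arguments are strictly required:
#   python demo/demo_skeleton.py demo/ntu_sample.avi demo/skeleton_demo.mp4 \
#       --device cuda:0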


def frame_extraction(video_path, short_side):
    """Extract frames given video_path.

    Args:
        video_path (str): The video path.
        short_side (int): Target length of the shorter side after resizing.

    Returns:
        tuple[list[str], list[np.ndarray]]: Paths of the written frames and
            the resized frames themselves.
    """
    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Should be able to handle videos up to several hours
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    vid = cv2.VideoCapture(video_path)
    frames = []
    frame_paths = []
    flag, frame = vid.read()
    cnt = 0
    new_h, new_w = None, None
    while flag:
        if new_h is None:
            h, w, _ = frame.shape
            new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf))

        frame = mmcv.imresize(frame, (new_w, new_h))

        frames.append(frame)
        frame_path = frame_tmpl.format(cnt + 1)
        frame_paths.append(frame_path)

        cv2.imwrite(frame_path, frame)
        cnt += 1
        flag, frame = vid.read()

    return frame_paths, frames


def detection_inference(args, frame_paths):
    """Detect human boxes given frame paths.

    Args:
        args (argparse.Namespace): The arguments.
        frame_paths (list[str]): The paths of frames to do detection inference.

    Returns:
        list[np.ndarray]: The human detection results.
    """
    model = init_detector(args.det_config, args.det_checkpoint, args.device)
    assert model.CLASSES[0] == 'person', ('We require you to use a detector '
                                          'trained on COCO')
    results = []
    print('Performing Human Detection for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))
    for frame_path in frame_paths:
        result = inference_detector(model, frame_path)
        # We only keep human detections with score larger than det_score_thr
        result = result[0][result[0][:, 4] >= args.det_score_thr]
        results.append(result)
        prog_bar.update()
    return results


def pose_inference(args, frame_paths, det_results):
    """Run top-down pose estimation given frame paths and detected boxes.

    Args:
        args (argparse.Namespace): The arguments.
        frame_paths (list[str]): The paths of frames to do pose inference.
        det_results (list[np.ndarray]): Per-frame human detection results.

    Returns:
        list[list[dict]]: Per-frame pose results; each dict holds the
            estimated ``keypoints`` of one detected person.
    """
    model = init_pose_model(args.pose_config, args.pose_checkpoint,
                            args.device)
    ret = []
    print('Performing Human Pose Estimation for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))
    for f, d in zip(frame_paths, det_results):
        # Align input format
        d = [dict(bbox=x) for x in list(d)]
        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
        ret.append(pose)
        prog_bar.update()
    return ret


def main():
    args = parse_args()

    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape

    # Update the PoseNormalize step of the test pipeline with the real
    # frame size
    config = mmcv.Config.fromfile(args.config)
    config.merge_from_dict(args.cfg_options)
    for component in config.data.test.pipeline:
        if component['type'] == 'PoseNormalize':
            component['mean'] = (w // 2, h // 2, .5)
            component['max_value'] = (w, h, 1.)

    model = init_recognizer(config, args.checkpoint, args.device)

    # Load label_map
    label_map = [x.strip() for x in open(args.label_map).readlines()]

    # Get Human detection results
    det_results = detection_inference(args, frame_paths)
    torch.cuda.empty_cache()

    pose_results = pose_inference(args, frame_paths, det_results)
    torch.cuda.empty_cache()

    fake_anno = dict(
        frame_dir='',
        label=-1,
        img_shape=(h, w),
        original_shape=(h, w),
        start_index=0,
        modality='Pose',
        total_frames=num_frame)
    num_person = max([len(x) for x in pose_results])

    num_keypoint = 17
    # Pack the per-frame, per-person keypoints into arrays of shape
    # (num_person, num_frame, num_keypoint, ...), the layout consumed by the
    # skeleton-based recognizer.
    keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
                        dtype=np.float16)
    keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
                              dtype=np.float16)
    for i, poses in enumerate(pose_results):
        for j, pose in enumerate(poses):
            pose = pose['keypoints']
            keypoint[j, i] = pose[:, :2]
            keypoint_score[j, i] = pose[:, 2]
    fake_anno['keypoint'] = keypoint
    fake_anno['keypoint_score'] = keypoint_score

    results = inference_recognizer(model, fake_anno)

    action_label = label_map[results[0][0]]

    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
                                 args.device)
    vis_frames = [
        vis_pose_result(pose_model, frame_paths[i], pose_results[i])
        for i in range(num_frame)
    ]
    for frame in vis_frames:
        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)

    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
    vid.write_videofile(args.out_filename, remove_temp=True)

    tmp_frame_dir = osp.dirname(frame_paths[0])
    shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()
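
For completeness, a minimal sketch of driving the demo from a Python session rather than the command line; it assumes the repository root is the working directory and that a sample clip exists at demo/ntu_sample.avi (the clip path is an assumption):

    import runpy
    import sys

    # Equivalent to: python demo/demo_skeleton.py demo/ntu_sample.avi demo/skeleton_demo.mp4
    sys.argv = ['demo/demo_skeleton.py', 'demo/ntu_sample.avi', 'demo/skeleton_demo.mp4']
    runpy.run_path('demo/demo_skeleton.py', run_name='__main__')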