demo/demo_skeleton.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil

import cv2
import mmcv
import numpy as np
import torch
from mmcv import DictAction

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.utils import import_module_error_func

try:
    from mmdet.apis import inference_detector, init_detector
    from mmpose.apis import (init_pose_model, inference_top_down_pose_model,
                             vis_pose_result)
except (ImportError, ModuleNotFoundError):

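    # mmdet and mmpose are optional dependencies here; define stand-in
    # functions that raise an informative "please install ..." error only
    # when they are actually called.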
    @import_module_error_func('mmdet')
    def inference_detector(*args, **kwargs):
        pass

    @import_module_error_func('mmdet')
    def init_detector(*args, **kwargs):
        pass

    @import_module_error_func('mmpose')
    def init_pose_model(*args, **kwargs):
        pass

    @import_module_error_func('mmpose')
    def inference_top_down_pose_model(*args, **kwargs):
        pass

    @import_module_error_func('mmpose')
    def vis_pose_result(*args, **kwargs):
        pass


try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable output file')

FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255)  # BGR, white
THICKNESS = 1
LINETYPE = 1


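# Example invocation (file names are placeholders, not shipped paths):
#   python demo/demo_skeleton.py ${VIDEO} ${OUT_FILENAME} --device cuda:0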
def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('video', help='video file/url')
    parser.add_argument('out_filename', help='output filename')
    parser.add_argument(
        '--config',
        default=('configs/skeleton/posec3d/'
                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint.py'),
        help='skeleton model config file path')
    parser.add_argument(
        '--checkpoint',
        default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/'
                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint/'
                 'slowonly_r50_u48_240e_ntu120_xsub_keypoint-6736b03f.pth'),
        help='skeleton model checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='demo/faster_rcnn_r50_fpn_2x_coco.py',
        help='human detection config file path (from mmdet)')
    parser.add_argument(
        '--det-checkpoint',
        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
                 'faster_rcnn_r50_fpn_2x_coco/'
                 'faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--pose-config',
        default='demo/hrnet_w32_coco_256x192.py',
        help='human pose estimation config file path (from mmpose)')
    parser.add_argument(
        '--pose-checkpoint',
        default=('https://download.openmmlab.com/mmpose/top_down/hrnet/'
                 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'),
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--det-score-thr',
        type=float,
        default=0.9,
        help='the threshold of human detection score')
    parser.add_argument(
        '--label-map',
        default='tools/data/skeleton/label_map_ntu120.txt',
        help='label map file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    args = parser.parse_args()
    return args


def frame_extraction(video_path, short_side):
    """Extract frames from the given video.

    Args:
        video_path (str): Path to the video file.
        short_side (int): Target length of the shorter side after rescaling.

    Returns:
        tuple[list[str], list[np.ndarray]]: The extracted frame paths and the
            extracted frames.
    """
    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Should be able to handle videos up to several hours
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    vid = cv2.VideoCapture(video_path)
    frames = []
    frame_paths = []
    flag, frame = vid.read()
    cnt = 0
    new_h, new_w = None, None
    while flag:
        if new_h is None:
            h, w, _ = frame.shape
            new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf))

        frame = mmcv.imresize(frame, (new_w, new_h))

        frames.append(frame)
        frame_path = frame_tmpl.format(cnt + 1)
        frame_paths.append(frame_path)

        cv2.imwrite(frame_path, frame)
        cnt += 1
        flag, frame = vid.read()

    return frame_paths, frames


def detection_inference(args, frame_paths):
    """Detect human boxes given frame paths.

    Args:
        args (argparse.Namespace): The arguments.
        frame_paths (list[str]): The paths of frames to do detection inference.

    Returns:
        list[np.ndarray]: The human detection results.
    """
    model = init_detector(args.det_config, args.det_checkpoint, args.device)
    assert model.CLASSES[0] == 'person', ('We require you to use a detector '
                                          'trained on COCO')
    results = []
    print('Performing Human Detection for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))
    for frame_path in frame_paths:
        result = inference_detector(model, frame_path)
        # We only keep human detections with score larger than det_score_thr
        result = result[0][result[0][:, 4] >= args.det_score_thr]
        results.append(result)
        prog_bar.update()
    return results


def pose_inference(args, frame_paths, det_results):
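    """Run top-down pose estimation on the detected human boxes.

    Args:
        args (argparse.Namespace): The arguments.
        frame_paths (list[str]): The paths of frames to do pose inference.
        det_results (list[np.ndarray]): The human detection results.

    Returns:
        list[list[dict]]: Per-frame pose estimation results.
    """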
    model = init_pose_model(args.pose_config, args.pose_checkpoint,
                            args.device)
    ret = []
    print('Performing Human Pose Estimation for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))
    for f, d in zip(frame_paths, det_results):
        # Align input format
        d = [dict(bbox=x) for x in list(d)]
        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
        ret.append(pose)
        prog_bar.update()
    return ret


def main():
    args = parse_args()

    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape

    # Get clip_len, frame_interval and calculate center index of each clip
    config = mmcv.Config.fromfile(args.config)
    config.merge_from_dict(args.cfg_options)
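    # PoseNormalize (if present in the test pipeline) normalizes keypoint
    # coordinates, so it needs the actual frame size of this video.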
    for component in config.data.test.pipeline:
        if component['type'] == 'PoseNormalize':
            component['mean'] = (w // 2, h // 2, .5)
            component['max_value'] = (w, h, 1.)

    model = init_recognizer(config, args.checkpoint, args.device)

    # Load label_map
    label_map = [x.strip() for x in open(args.label_map).readlines()]

    # Get Human detection results
    det_results = detection_inference(args, frame_paths)
    torch.cuda.empty_cache()

    pose_results = pose_inference(args, frame_paths, det_results)
    torch.cuda.empty_cache()

    fake_anno = dict(
        frame_dir='',
        label=-1,
        img_shape=(h, w),
        original_shape=(h, w),
        start_index=0,
        modality='Pose',
        total_frames=num_frame)
    num_person = max([len(x) for x in pose_results])

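    # Pack the per-frame pose results into fixed-size arrays of shape
    # (num_person, num_frame, num_keypoint, ...), the input format the
    # skeleton-based recognizer expects (COCO has 17 keypoints).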
    num_keypoint = 17
    keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
                        dtype=np.float16)
    keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
                              dtype=np.float16)
    for i, poses in enumerate(pose_results):
        for j, pose in enumerate(poses):
            pose = pose['keypoints']
            keypoint[j, i] = pose[:, :2]
            keypoint_score[j, i] = pose[:, 2]
    fake_anno['keypoint'] = keypoint
    fake_anno['keypoint_score'] = keypoint_score

    results = inference_recognizer(model, fake_anno)

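    # results holds (label_index, score) pairs with the highest score first,
    # so results[0][0] is the index of the top-1 predicted action.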
    action_label = label_map[results[0][0]]

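    # Initialize a pose model again to draw skeletons on every frame and
    # overlay the predicted action label on top.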
    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
                                 args.device)
    vis_frames = [
        vis_pose_result(pose_model, frame_paths[i], pose_results[i])
        for i in range(num_frame)
    ]
    for frame in vis_frames:
        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)

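    # vis_pose_result returns BGR images (OpenCV channel order); reverse the
    # channel order to RGB before handing the frames to moviepy.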
    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
    vid.write_videofile(args.out_filename, remove_temp=True)

    tmp_frame_dir = osp.dirname(frame_paths[0])
    shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()