--- /dev/null
+++ b/ViTPose/demo/webcam_demo.py
@@ -0,0 +1,585 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import time
+from collections import deque
+from queue import Queue
+from threading import Event, Lock, Thread
+
+import cv2
+import numpy as np
+
+from mmpose.apis import (get_track_id, inference_top_down_pose_model,
+                         init_pose_model, vis_pose_result)
+from mmpose.core import apply_bugeye_effect, apply_sunglasses_effect
+from mmpose.utils import StopWatch
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+    has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+    has_mmdet = False
+
+try:
+    import psutil
+    psutil_proc = psutil.Process()
+except (ImportError, ModuleNotFoundError):
+    psutil_proc = None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--cam-id',
+        type=str,
+        default='0',
+        help='Camera device ID. A non-digit value is passed to '
+        'cv2.VideoCapture unchanged, so a video file path or stream URL '
+        'also works.')
+    parser.add_argument(
+        '--det-config',
+        type=str,
+        default='demo/mmdetection_cfg/'
+        'ssdlite_mobilenetv2_scratch_600e_coco.py',
+        help='Config file for detection')
+    parser.add_argument(
+        '--det-checkpoint',
+        type=str,
+        default='https://download.openmmlab.com/mmdetection/v2.0/ssd/'
+        'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+        'scratch_600e_coco_20210629_110627-974d9307.pth',
+        help='Checkpoint file for detection')
+    parser.add_argument(
+        '--enable-human-pose',
+        type=int,
+        default=1,
+        help='Enable human pose estimation')
+    parser.add_argument(
+        '--enable-animal-pose',
+        type=int,
+        default=0,
+        help='Enable animal pose estimation')
+    parser.add_argument(
+        '--human-pose-config',
+        type=str,
+        default='configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/'
+        'coco-wholebody/vipnas_res50_coco_wholebody_256x192_dark.py',
+        help='Config file for human pose')
+    parser.add_argument(
+        '--human-pose-checkpoint',
+        type=str,
+        default='https://download.openmmlab.com/'
+        'mmpose/top_down/vipnas/'
+        'vipnas_res50_wholebody_256x192_dark-67c0ce35_20211112.pth',
+        help='Checkpoint file for human pose')
+    parser.add_argument(
+        '--human-det-ids',
+        type=int,
+        default=[1],
+        nargs='+',
+        help='Object category label of human in detection results. '
+        'Default is [1(person)], following COCO definition.')
+    parser.add_argument(
+        '--animal-pose-config',
+        type=str,
+        default='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/'
+        'animalpose/hrnet_w32_animalpose_256x256.py',
+        help='Config file for animal pose')
+    parser.add_argument(
+        '--animal-pose-checkpoint',
+        type=str,
+        default='https://download.openmmlab.com/mmpose/animal/hrnet/'
+        'hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+        help='Checkpoint file for animal pose')
+    parser.add_argument(
+        '--animal-det-ids',
+        type=int,
+        default=[16, 17, 18, 19, 20],
+        nargs='+',
+        help='Object category label of animals in detection results. '
+        'Default is [16(cat), 17(dog), 18(horse), 19(sheep), 20(cow)], '
+        'following COCO definition.')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--det-score-thr',
+        type=float,
+        default=0.5,
+        help='bbox score threshold')
+    parser.add_argument(
+        '--kpt-thr',
+        type=float,
+        default=0.3,
+        help='keypoint score threshold')
+    parser.add_argument(
+        '--vis-mode',
+        type=int,
+        default=2,
+        help='0-none. 1-detection only. 2-detection and pose.')
+    parser.add_argument(
+        '--sunglasses', action='store_true', help='Apply `sunglasses` effect.')
+    parser.add_argument(
+        '--bugeye', action='store_true', help='Apply `bug-eye` effect.')
+
+    parser.add_argument(
+        '--out-video-file',
+        type=str,
+        default=None,
+        help='Record the video into a file. This may reduce the frame rate')
+
+    parser.add_argument(
+        '--out-video-fps',
+        type=int,
+        default=20,
+        help='Set the FPS of the output video file.')
+
+    parser.add_argument(
+        '--buffer-size',
+        type=int,
+        default=-1,
+        help='Frame buffer size. If set to -1, the buffer size will be '
+        'automatically inferred from the display delay time. Default: -1')
+
+    parser.add_argument(
+        '--inference-fps',
+        type=int,
+        default=10,
+        help='Maximum inference FPS. This limits resource consumption, '
+        'especially when the detection and pose models are lightweight '
+        'and very fast. Default: 10.')
+
+    parser.add_argument(
+        '--display-delay',
+        type=int,
+        default=0,
+        help='Delay the output video in milliseconds. This can be used to '
+        'align the output video and inference results. The delay can be '
+        'disabled by setting a non-positive delay time. Default: 0')
+
+    parser.add_argument(
+        '--synchronous-mode',
+        action='store_true',
+        help='Enable synchronous mode so that video I/O and inference are '
+        'temporally aligned. Note that this will reduce the display FPS.')
+
+    return parser.parse_args()
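+
+# Minimal invocation sketches (assuming the repository root as the working
+# directory so the default config paths resolve; the default checkpoints
+# are fetched from the URLs above on first use):
+#
+#     python demo/webcam_demo.py --cam-id 0 --vis-mode 2
+#
+# Because a non-digit --cam-id is handed to cv2.VideoCapture unchanged, a
+# video file (hypothetical path below) can stand in for the camera:
+#
+#     python demo/webcam_demo.py --cam-id /path/to/video.mp4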
+
+
+def process_mmdet_results(mmdet_results, class_names=None, cat_ids=1):
+    """Process mmdet results into mmpose input format.
+
+    Args:
+        mmdet_results: raw output of the mmdet model
+        class_names: class names of the mmdet model
+        cat_ids (int or List[int]): 1-based category ids to keep
+    Returns:
+        List[Dict]: detection results in mmpose input format
+    """
+    if isinstance(mmdet_results, tuple):
+        mmdet_results = mmdet_results[0]
+
+    if not isinstance(cat_ids, (list, tuple)):
+        cat_ids = [cat_ids]
+
+    # only keep bboxes of interested classes
+    # (cat_ids are 1-based labels, while mmdet_results is indexed from 0)
+    bbox_results = [mmdet_results[i - 1] for i in cat_ids]
+    bboxes = np.vstack(bbox_results)
+
+    # get textual labels of classes
+    labels = np.concatenate([
+        np.full(bbox.shape[0], i - 1, dtype=np.int32)
+        for i, bbox in zip(cat_ids, bbox_results)
+    ])
+    if class_names is None:
+        labels = [f'class: {i}' for i in labels]
+    else:
+        labels = [class_names[i] for i in labels]
+
+    det_results = []
+    for bbox, label in zip(bboxes, labels):
+        det_result = dict(bbox=bbox, label=label)
+        det_results.append(det_result)
+    return det_results
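+
+# A sanity-check sketch of the conversion above (hedged: the bbox numbers
+# are made up for illustration). mmdet returns a per-class list of (n, 5)
+# arrays in [x1, y1, x2, y2, score] format; cat_ids are 1-based, so class 1
+# (person) maps to index 0 of that list:
+#
+#     >>> person_bboxes = np.array([[10., 20., 110., 220., 0.98]])
+#     >>> process_mmdet_results([person_bboxes], cat_ids=1)
+#     [{'bbox': array([ 10., 20., 110., 220., 0.98]), 'label': 'class: 0'}]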
+
+
+def read_camera():
+    # init video reader
+    print('Thread "input" started')
+    cam_id = args.cam_id
+    if cam_id.isdigit():
+        cam_id = int(cam_id)
+    vid_cap = cv2.VideoCapture(cam_id)
+    if not vid_cap.isOpened():
+        print(f'Cannot open camera (ID={cam_id})')
+        exit()
+
+    while not event_exit.is_set():
+        # capture a camera frame
+        ret_val, frame = vid_cap.read()
+        if ret_val:
+            ts_input = time.time()
+
+            event_inference_done.clear()
+            with input_queue_mutex:
+                input_queue.append((ts_input, frame))
+
+            if args.synchronous_mode:
+                event_inference_done.wait()
+
+            frame_buffer.put((ts_input, frame))
+        else:
+            # input ending signal
+            frame_buffer.put((None, None))
+            break
+
+    vid_cap.release()
+
+
+def inference_detection():
+    print('Thread "det" started')
+    stop_watch = StopWatch(window=10)
+    min_interval = 1.0 / args.inference_fps
+    _ts_last = None  # timestamp when last inference was done
+
+    while True:
+        while len(input_queue) < 1:
+            time.sleep(0.001)
+        with input_queue_mutex:
+            ts_input, frame = input_queue.popleft()
+        # inference detection
+        with stop_watch.timeit('Det'):
+            mmdet_results = inference_detector(det_model, frame)
+
+        t_info = stop_watch.report_strings()
+        with det_result_queue_mutex:
+            det_result_queue.append((ts_input, frame, t_info, mmdet_results))
+
+        # limit the inference FPS
+        _ts = time.time()
+        if _ts_last is not None and _ts - _ts_last < min_interval:
+            time.sleep(min_interval - _ts + _ts_last)
+        _ts_last = time.time()
+
+
+def inference_pose():
+    print('Thread "pose" started')
+    stop_watch = StopWatch(window=10)
+
+    while True:
+        while len(det_result_queue) < 1:
+            time.sleep(0.001)
+        with det_result_queue_mutex:
+            ts_input, frame, t_info, mmdet_results = det_result_queue.popleft()
+
+        pose_results_list = []
+        for model_info, pose_history in zip(pose_model_list,
+                                            pose_history_list):
+            model_name = model_info['name']
+            pose_model = model_info['model']
+            cat_ids = model_info['cat_ids']
+            pose_results_last = pose_history['pose_results_last']
+            next_id = pose_history['next_id']
+
+            with stop_watch.timeit(model_name):
+                # process mmdet results
+                det_results = process_mmdet_results(
+                    mmdet_results,
+                    class_names=det_model.CLASSES,
+                    cat_ids=cat_ids)
+
+                # inference pose model
+                dataset_name = pose_model.cfg.data['test']['type']
+                pose_results, _ = inference_top_down_pose_model(
+                    pose_model,
+                    frame,
+                    det_results,
+                    bbox_thr=args.det_score_thr,
+                    format='xyxy',
+                    dataset=dataset_name)
+
+                pose_results, next_id = get_track_id(
+                    pose_results,
+                    pose_results_last,
+                    next_id,
+                    use_oks=False,
+                    tracking_thr=0.3,
+                    use_one_euro=True,
+                    fps=None)
+
+                pose_results_list.append(pose_results)
+
+                # update pose history
+                pose_history['pose_results_last'] = pose_results
+                pose_history['next_id'] = next_id
+
+        t_info += stop_watch.report_strings()
+        with pose_result_queue_mutex:
+            pose_result_queue.append((ts_input, t_info, pose_results_list))
+
+        event_inference_done.set()
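+
+# Note on the pipeline design: input_queue, det_result_queue and
+# pose_result_queue are deque(maxlen=1) instances (created in main() below),
+# so each stage always consumes the newest item and stale ones are silently
+# dropped. This keeps end-to-end latency low at the cost of skipping frames
+# whenever inference runs slower than capture. With --synchronous-mode,
+# read_camera() instead blocks on event_inference_done, which
+# inference_pose() sets once per frame, so no frame is skipped.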
+
+
+def display():
+    print('Thread "display" started')
+    stop_watch = StopWatch(window=10)
+
+    # initialize result status
+    ts_inference = None  # timestamp of the latest inference result
+    fps_inference = 0.  # inference FPS
+    t_delay_inference = 0.  # inference result time delay
+    pose_results_list = None  # latest inference result
+    t_info = []  # upstream time information (list[str])
+
+    # initialize visualization and output
+    sunglasses_img = None  # resource image for sunglasses effect
+    text_color = (228, 183, 61)  # text color to show time/system information
+    vid_out = None  # video writer
+
+    # show instructions
+    print('Keyboard shortcuts: ')
+    print('"v": Toggle the visualization of bounding boxes and poses.')
+    print('"s": Toggle the sunglasses effect.')
+    print('"b": Toggle the bug-eye effect.')
+    print('"Q", "q" or Esc: Exit.')
+
+    while True:
+        with stop_watch.timeit('_FPS_'):
+            # acquire a frame from buffer
+            ts_input, frame = frame_buffer.get()
+            # input ending signal
+            if ts_input is None:
+                break
+
+            img = frame
+
+            # get pose estimation results
+            if len(pose_result_queue) > 0:
+                with pose_result_queue_mutex:
+                    _result = pose_result_queue.popleft()
+                    _ts_input, t_info, pose_results_list = _result
+
+                _ts = time.time()
+                if ts_inference is not None:
+                    fps_inference = 1.0 / (_ts - ts_inference)
+                ts_inference = _ts
+                t_delay_inference = (_ts - _ts_input) * 1000
+
+            # visualize detection and pose results
+            if pose_results_list is not None:
+                for model_info, pose_results in zip(pose_model_list,
+                                                    pose_results_list):
+                    pose_model = model_info['model']
+                    bbox_color = model_info['bbox_color']
+
+                    dataset_name = pose_model.cfg.data['test']['type']
+
+                    # show pose results
+                    if args.vis_mode == 1:
+                        # a huge kpt_score_thr suppresses all keypoints,
+                        # so only the detection bboxes are drawn
+                        img = vis_pose_result(
+                            pose_model,
+                            img,
+                            pose_results,
+                            radius=4,
+                            thickness=2,
+                            dataset=dataset_name,
+                            kpt_score_thr=1e7,
+                            bbox_color=bbox_color)
+                    elif args.vis_mode == 2:
+                        img = vis_pose_result(
+                            pose_model,
+                            img,
+                            pose_results,
+                            radius=4,
+                            thickness=2,
+                            dataset=dataset_name,
+                            kpt_score_thr=args.kpt_thr,
+                            bbox_color=bbox_color)
+
+                    # sunglasses effect
+                    if args.sunglasses:
+                        if dataset_name in {
+                                'TopDownCocoDataset',
+                                'TopDownCocoWholeBodyDataset'
+                        }:
+                            left_eye_idx = 1
+                            right_eye_idx = 2
+                        elif dataset_name == 'AnimalPoseDataset':
+                            left_eye_idx = 0
+                            right_eye_idx = 1
+                        else:
+                            raise ValueError(
+                                'Sunglasses effect does not support '
+                                f'{dataset_name}')
+                        if sunglasses_img is None:
+                            # Image attribution:
+                            # https://www.vecteezy.com/free-vector/glass
+                            # Glass Vectors by Vecteezy
+                            sunglasses_img = cv2.imread(
+                                'demo/resources/sunglasses.jpg')
+                        img = apply_sunglasses_effect(img, pose_results,
+                                                      sunglasses_img,
+                                                      left_eye_idx,
+                                                      right_eye_idx)
+                    # bug-eye effect
+                    if args.bugeye:
+                        if dataset_name in {
+                                'TopDownCocoDataset',
+                                'TopDownCocoWholeBodyDataset'
+                        }:
+                            left_eye_idx = 1
+                            right_eye_idx = 2
+                        elif dataset_name == 'AnimalPoseDataset':
+                            left_eye_idx = 0
+                            right_eye_idx = 1
+                        else:
+                            raise ValueError(
+                                'Bug-eye effect does not support '
+                                f'{dataset_name}')
+                        img = apply_bugeye_effect(img, pose_results,
+                                                  left_eye_idx, right_eye_idx)
+
+            # delay control
+            if args.display_delay > 0:
+                t_sleep = args.display_delay * 0.001 - (time.time() - ts_input)
+                if t_sleep > 0:
+                    time.sleep(t_sleep)
+            t_delay = (time.time() - ts_input) * 1000
+
+            # show time information
+            t_info_display = stop_watch.report_strings()  # display fps
+            t_info_display.append(f'Inference FPS: {fps_inference:>5.1f}')
+            t_info_display.append(f'Delay: {t_delay:>3.0f}')
+            t_info_display.append(
+                f'Inference Delay: {t_delay_inference:>3.0f}')
+            t_info_str = ' | '.join(t_info_display + t_info)
+            cv2.putText(img, t_info_str, (20, 20), cv2.FONT_HERSHEY_DUPLEX,
+                        0.3, text_color, 1)
+            # collect system information
+            sys_info = [
+                f'RES: {img.shape[1]}x{img.shape[0]}',
+                f'Buffer: {frame_buffer.qsize()}/{frame_buffer.maxsize}'
+            ]
+            if psutil_proc is not None:
+                sys_info += [
+                    f'CPU: {psutil_proc.cpu_percent():.1f}%',
+                    f'MEM: {psutil_proc.memory_percent():.1f}%'
+                ]
+            sys_info_str = ' | '.join(sys_info)
+            cv2.putText(img, sys_info_str, (20, 40), cv2.FONT_HERSHEY_DUPLEX,
+                        0.3, text_color, 1)
+
+            # save the output video frame
+            if args.out_video_file is not None:
+                if vid_out is None:
+                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                    fps = args.out_video_fps
+                    frame_size = (img.shape[1], img.shape[0])
+                    vid_out = cv2.VideoWriter(args.out_video_file, fourcc,
+                                              fps, frame_size)
+
+                vid_out.write(img)
+
+            # display
+            cv2.imshow('mmpose webcam demo', img)
+            keyboard_input = cv2.waitKey(1)
+            if keyboard_input in (27, ord('q'), ord('Q')):
+                break
+            elif keyboard_input == ord('s'):
+                args.sunglasses = not args.sunglasses
+            elif keyboard_input == ord('b'):
+                args.bugeye = not args.bugeye
+            elif keyboard_input == ord('v'):
+                args.vis_mode = (args.vis_mode + 1) % 3
+
+    cv2.destroyAllWindows()
+    if vid_out is not None:
+        vid_out.release()
+    event_exit.set()
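+
+# Portability note on the hotkey handling above: on some platforms
+# cv2.waitKey() may return the key code with extra high bits set, so a
+# common defensive variant masks the result before comparing:
+#
+#     keyboard_input = cv2.waitKey(1) & 0xFF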
+
+
+def main():
+    global args
+    global frame_buffer
+    global input_queue, input_queue_mutex
+    global det_result_queue, det_result_queue_mutex
+    global pose_result_queue, pose_result_queue_mutex
+    global det_model, pose_model_list, pose_history_list
+    global event_exit, event_inference_done
+
+    args = parse_args()
+
+    assert has_mmdet, 'Please install mmdet to run the demo.'
+    assert args.det_config is not None
+    assert args.det_checkpoint is not None
+
+    # build detection model
+    det_model = init_detector(
+        args.det_config, args.det_checkpoint, device=args.device.lower())
+
+    # build pose models
+    pose_model_list = []
+    if args.enable_human_pose:
+        pose_model = init_pose_model(
+            args.human_pose_config,
+            args.human_pose_checkpoint,
+            device=args.device.lower())
+        model_info = {
+            'name': 'HumanPose',
+            'model': pose_model,
+            'cat_ids': args.human_det_ids,
+            'bbox_color': (148, 139, 255),
+        }
+        pose_model_list.append(model_info)
+    if args.enable_animal_pose:
+        pose_model = init_pose_model(
+            args.animal_pose_config,
+            args.animal_pose_checkpoint,
+            device=args.device.lower())
+        model_info = {
+            'name': 'AnimalPose',
+            'model': pose_model,
+            'cat_ids': args.animal_det_ids,
+            'bbox_color': 'cyan',
+        }
+        pose_model_list.append(model_info)
+
+    # store pose history for pose tracking
+    pose_history_list = []
+    for _ in range(len(pose_model_list)):
+        pose_history_list.append({'pose_results_last': [], 'next_id': 0})
+
+    # frame buffer
+    if args.buffer_size > 0:
+        buffer_size = args.buffer_size
+    else:
+        # infer buffer size from the display delay time
+        # assume that the maximum video fps is 30
+        buffer_size = round(30 * (1 + max(args.display_delay, 0) / 1000.))
+    frame_buffer = Queue(maxsize=buffer_size)
+
+    # queue of input frames
+    # element: (timestamp, frame)
+    input_queue = deque(maxlen=1)
+    input_queue_mutex = Lock()
+
+    # queue of detection results
+    # element: tuple(timestamp, frame, time_info, det_results)
+    det_result_queue = deque(maxlen=1)
+    det_result_queue_mutex = Lock()
+
+    # queue of pose estimation results
+    # element: (timestamp, time_info, pose_results_list)
+    pose_result_queue = deque(maxlen=1)
+    pose_result_queue_mutex = Lock()
+
+    try:
+        event_exit = Event()
+        event_inference_done = Event()
+        t_input = Thread(target=read_camera, args=())
+        # inference threads are daemonized so they do not block program exit
+        t_det = Thread(target=inference_detection, args=(), daemon=True)
+        t_pose = Thread(target=inference_pose, args=(), daemon=True)
+
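+
+# Worked example of the automatic buffer sizing in main() (the code assumes
+# a maximum camera rate of 30 FPS): with --display-delay 500, buffer_size =
+# round(30 * (1 + 500 / 1000)) = 45, i.e. room for the extra half second of
+# frames that wait in the buffer before being displayed.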
+        t_input.start()
+        t_det.start()
+        t_pose.start()
+
+        # run display in the main thread
+        display()
+        # join the input thread (non-daemon)
+        t_input.join()
+
+    except KeyboardInterrupt:
+        pass
+
+
+if __name__ == '__main__':
+    main()