--- a
+++ b/ViTPose/demo/webcam_demo.py
@@ -0,0 +1,585 @@
+# Copyright (c) OpenMMLab. All rights reserved.
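+"""Webcam demo: multi-threaded detection (mmdet) and top-down pose
+estimation (mmpose), with optional sunglasses and bug-eye effects.
+
+Example (a hypothetical invocation from the repo root, using the default
+detector and human pose model):
+
+    python demo/webcam_demo.py --cam-id 0 --vis-mode 2
+"""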
+import argparse
+import time
+from collections import deque
+from queue import Queue
+from threading import Event, Lock, Thread
+
+import cv2
+import numpy as np
+
+from mmpose.apis import (get_track_id, inference_top_down_pose_model,
+                         init_pose_model, vis_pose_result)
+from mmpose.core import apply_bugeye_effect, apply_sunglasses_effect
+from mmpose.utils import StopWatch
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+    has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+    has_mmdet = False
+
+try:
+    import psutil
+    psutil_proc = psutil.Process()
+except (ImportError, ModuleNotFoundError):
+    psutil_proc = None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--cam-id',
+        type=str,
+        default='0',
+        help='Camera device ID, or the path/URL of a video source')
+    parser.add_argument(
+        '--det-config',
+        type=str,
+        default='demo/mmdetection_cfg/'
+        'ssdlite_mobilenetv2_scratch_600e_coco.py',
+        help='Config file for detection')
+    parser.add_argument(
+        '--det-checkpoint',
+        type=str,
+        default='https://download.openmmlab.com/mmdetection/v2.0/ssd/'
+        'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+        'scratch_600e_coco_20210629_110627-974d9307.pth',
+        help='Checkpoint file for detection')
+    parser.add_argument(
+        '--enable-human-pose',
+        type=int,
+        default=1,
+        help='Enable human pose estimation')
+    parser.add_argument(
+        '--enable-animal-pose',
+        type=int,
+        default=0,
+        help='Enable animal pose estimation')
+    parser.add_argument(
+        '--human-pose-config',
+        type=str,
+        default='configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/'
+        'coco-wholebody/vipnas_res50_coco_wholebody_256x192_dark.py',
+        help='Config file for human pose')
+    parser.add_argument(
+        '--human-pose-checkpoint',
+        type=str,
+        default='https://download.openmmlab.com/'
+        'mmpose/top_down/vipnas/'
+        'vipnas_res50_wholebody_256x192_dark-67c0ce35_20211112.pth',
+        help='Checkpoint file for human pose')
+    parser.add_argument(
+        '--human-det-ids',
+        type=int,
+        default=[1],
+        nargs='+',
+        help='Object category label of human in detection results. '
+        'Default is [1 (person)], following the COCO definition.')
+    parser.add_argument(
+        '--animal-pose-config',
+        type=str,
+        default='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/'
+        'animalpose/hrnet_w32_animalpose_256x256.py',
+        help='Config file for animal pose')
+    parser.add_argument(
+        '--animal-pose-checkpoint',
+        type=str,
+        default='https://download.openmmlab.com/mmpose/animal/hrnet/'
+        'hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+        help='Checkpoint file for animal pose')
+    parser.add_argument(
+        '--animal-det-ids',
+        type=int,
+        default=[16, 17, 18, 19, 20],
+        nargs='+',
+        help='Object category labels of animals in detection results. '
+        'Default is [16 (cat), 17 (dog), 18 (horse), 19 (sheep), 20 (cow)], '
+        'following the COCO definition.')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--det-score-thr',
+        type=float,
+        default=0.5,
+        help='bbox score threshold')
+    parser.add_argument(
+        '--kpt-thr', type=float, default=0.3, help='keypoint score threshold')
+    parser.add_argument(
+        '--vis-mode',
+        type=int,
+        default=2,
+        help='0-none. 1-detection only. 2-detection and pose.')
+    parser.add_argument(
+        '--sunglasses', action='store_true', help='Apply `sunglasses` effect.')
+    parser.add_argument(
+        '--bugeye', action='store_true', help='Apply `bug-eye` effect.')
+
+    parser.add_argument(
+        '--out-video-file',
+        type=str,
+        default=None,
+        help='Record the video into a file. This may reduce the frame rate')
+
+    parser.add_argument(
+        '--out-video-fps',
+        type=int,
+        default=20,
+        help='Set the FPS of the output video file.')
+
+    parser.add_argument(
+        '--buffer-size',
+        type=int,
+        default=-1,
+        help='Frame buffer size. If set to -1, the buffer size will be '
+        'automatically inferred from the display delay time. Default: -1')
+
+    parser.add_argument(
+        '--inference-fps',
+        type=int,
+        default=10,
+        help='Maximum inference FPS. This limits resource consumption, '
+        'especially when the detection and pose models are lightweight and '
+        'very fast. Default: 10.')
+
+    parser.add_argument(
+        '--display-delay',
+        type=int,
+        default=0,
+        help='Delay the output video in milliseconds. This can be used to '
+        'align the output video and inference results. The delay can be '
+        'disabled by setting a non-positive delay time. Default: 0')
+
+    parser.add_argument(
+        '--synchronous-mode',
+        action='store_true',
+        help='Enable synchronous mode that video I/O and inference will be '
+        'temporally aligned. Note that this will reduce the display FPS.')
+
+    return parser.parse_args()
+
+
+def process_mmdet_results(mmdet_results, class_names=None, cat_ids=1):
+    """Process mmdet results to mmpose input format.
+
+    Args:
+        mmdet_results: raw output of mmdet model
+        class_names: class names of mmdet model
+        cat_ids (int or List[int]): category id list that will be preserved
+    Returns:
+        List[Dict]: detection results for mmpose input
+    """
+    if isinstance(mmdet_results, tuple):
+        mmdet_results = mmdet_results[0]
+
+    if not isinstance(cat_ids, (list, tuple)):
+        cat_ids = [cat_ids]
+
+    # only keep bboxes of interested classes
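+    # mmdet returns one bbox array per class (0-indexed), while cat_ids
+    # follow the 1-indexed COCO convention, hence the `i - 1` offset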
+    bbox_results = [mmdet_results[i - 1] for i in cat_ids]
+    bboxes = np.vstack(bbox_results)
+
+    # build a per-bbox class label (textual if class_names is given)
+    labels = np.concatenate([
+        np.full(bbox.shape[0], i - 1, dtype=np.int32)
+        for i, bbox in zip(cat_ids, bbox_results)
+    ])
+    if class_names is None:
+        labels = [f'class: {i}' for i in labels]
+    else:
+        labels = [class_names[i] for i in labels]
+
+    det_results = []
+    for bbox, label in zip(bboxes, labels):
+        det_result = dict(bbox=bbox, label=label)
+        det_results.append(det_result)
+    return det_results
+
+
+def read_camera():
+    # init video reader
+    print('Thread "input" started')
+    cam_id = args.cam_id
+    if cam_id.isdigit():
+        cam_id = int(cam_id)
+    vid_cap = cv2.VideoCapture(cam_id)
+    if not vid_cap.isOpened():
+        print(f'Cannot open camera (ID={cam_id})')
+        # signal the display thread to stop, so the process can exit cleanly
+        frame_buffer.put((None, None))
+        return
+
+    while not event_exit.is_set():
+        # capture a camera frame
+        ret_val, frame = vid_cap.read()
+        if ret_val:
+            ts_input = time.time()
+
+            event_inference_done.clear()
+            with input_queue_mutex:
+                input_queue.append((ts_input, frame))
+
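+            # in synchronous mode, block until inference on this frame is
+            # done so that video I/O and inference stay temporally aligned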
+            if args.synchronous_mode:
+                event_inference_done.wait()
+
+            frame_buffer.put((ts_input, frame))
+        else:
+            # input ending signal
+            frame_buffer.put((None, None))
+            break
+
+    vid_cap.release()
+
+
+def inference_detection():
+    print('Thread "det" started')
+    stop_watch = StopWatch(window=10)
+    min_interval = 1.0 / args.inference_fps
+    _ts_last = None  # timestamp when last inference was done
+
+    while True:
+        while len(input_queue) < 1:
+            time.sleep(0.001)
+        with input_queue_mutex:
+            ts_input, frame = input_queue.popleft()
+        # inference detection
+        with stop_watch.timeit('Det'):
+            mmdet_results = inference_detector(det_model, frame)
+
+        t_info = stop_watch.report_strings()
+        with det_result_queue_mutex:
+            det_result_queue.append((ts_input, frame, t_info, mmdet_results))
+
+        # limit the inference FPS
+        _ts = time.time()
+        if _ts_last is not None and _ts - _ts_last < min_interval:
+            time.sleep(min_interval - _ts + _ts_last)
+        _ts_last = time.time()
+
+
+def inference_pose():
+    print('Thread "pose" started')
+    stop_watch = StopWatch(window=10)
+
+    while True:
+        while len(det_result_queue) < 1:
+            time.sleep(0.001)
+        with det_result_queue_mutex:
+            ts_input, frame, t_info, mmdet_results = det_result_queue.popleft()
+
+        pose_results_list = []
+        for model_info, pose_history in zip(pose_model_list,
+                                            pose_history_list):
+            model_name = model_info['name']
+            pose_model = model_info['model']
+            cat_ids = model_info['cat_ids']
+            pose_results_last = pose_history['pose_results_last']
+            next_id = pose_history['next_id']
+
+            with stop_watch.timeit(model_name):
+                # process mmdet results
+                det_results = process_mmdet_results(
+                    mmdet_results,
+                    class_names=det_model.CLASSES,
+                    cat_ids=cat_ids)
+
+                # inference pose model
+                dataset_name = pose_model.cfg.data['test']['type']
+                pose_results, _ = inference_top_down_pose_model(
+                    pose_model,
+                    frame,
+                    det_results,
+                    bbox_thr=args.det_score_thr,
+                    format='xyxy',
+                    dataset=dataset_name)
+
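+                # match detections to the previous frame to assign track
+                # ids; use_one_euro applies One-Euro filtering to smooth
+                # keypoint jitter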
+                pose_results, next_id = get_track_id(
+                    pose_results,
+                    pose_results_last,
+                    next_id,
+                    use_oks=False,
+                    tracking_thr=0.3,
+                    use_one_euro=True,
+                    fps=None)
+
+                pose_results_list.append(pose_results)
+
+                # update pose history
+                pose_history['pose_results_last'] = pose_results
+                pose_history['next_id'] = next_id
+
+        t_info += stop_watch.report_strings()
+        with pose_result_queue_mutex:
+            pose_result_queue.append((ts_input, t_info, pose_results_list))
+
+        event_inference_done.set()
+
+
+def display():
+    print('Thread "display" started')
+    stop_watch = StopWatch(window=10)
+
+    # initialize result status
+    ts_inference = None  # timestamp of the latest inference result
+    fps_inference = 0.  # inference FPS
+    t_delay_inference = 0.  # inference result time delay
+    pose_results_list = None  # latest inference result
+    t_info = []  # upstream time information (list[str])
+
+    # initialize visualization and output
+    sunglasses_img = None  # resource image for sunglasses effect
+    text_color = (228, 183, 61)  # text color to show time/system information
+    vid_out = None  # video writer
+
+    # show instructions
+    print('Keyboard shortcuts: ')
+    print('"v": Toggle the visualization of bounding boxes and poses.')
+    print('"s": Toggle the sunglasses effect.')
+    print('"b": Toggle the bug-eye effect.')
+    print('"Q", "q" or Esc: Exit.')
+
+    while True:
+        with stop_watch.timeit('_FPS_'):
+            # acquire a frame from buffer
+            ts_input, frame = frame_buffer.get()
+            # input ending signal
+            if ts_input is None:
+                break
+
+            img = frame
+
+            # get pose estimation results
+            if len(pose_result_queue) > 0:
+                with pose_result_queue_mutex:
+                    _result = pose_result_queue.popleft()
+                    _ts_input, t_info, pose_results_list = _result
+
+                _ts = time.time()
+                if ts_inference is not None:
+                    fps_inference = 1.0 / (_ts - ts_inference)
+                ts_inference = _ts
+                t_delay_inference = (_ts - _ts_input) * 1000
+
+            # visualize detection and pose results
+            if pose_results_list is not None:
+                for model_info, pose_results in zip(pose_model_list,
+                                                    pose_results_list):
+                    pose_model = model_info['model']
+                    bbox_color = model_info['bbox_color']
+
+                    dataset_name = pose_model.cfg.data['test']['type']
+
+                    # show pose results
+                    if args.vis_mode == 1:
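+                        # kpt_score_thr=1e7 is unreachably high, so no
+                        # keypoints are drawn and only bboxes are shown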
+                        img = vis_pose_result(
+                            pose_model,
+                            img,
+                            pose_results,
+                            radius=4,
+                            thickness=2,
+                            dataset=dataset_name,
+                            kpt_score_thr=1e7,
+                            bbox_color=bbox_color)
+                    elif args.vis_mode == 2:
+                        img = vis_pose_result(
+                            pose_model,
+                            img,
+                            pose_results,
+                            radius=4,
+                            thickness=2,
+                            dataset=dataset_name,
+                            kpt_score_thr=args.kpt_thr,
+                            bbox_color=bbox_color)
+
+                    # sunglasses effect
+                    if args.sunglasses:
+                        if dataset_name in {
+                                'TopDownCocoDataset',
+                                'TopDownCocoWholeBodyDataset'
+                        }:
+                            left_eye_idx = 1
+                            right_eye_idx = 2
+                        elif dataset_name == 'AnimalPoseDataset':
+                            left_eye_idx = 0
+                            right_eye_idx = 1
+                        else:
+                            raise ValueError(
+                                'Sunglasses effect does not support '
+                                f'{dataset_name}')
+                        if sunglasses_img is None:
+                            # Image attribution:
+                            # https://www.vecteezy.com/free-vector/glass
+                            # Glass Vectors by Vecteezy
+                            sunglasses_img = cv2.imread(
+                                'demo/resources/sunglasses.jpg')
+                        img = apply_sunglasses_effect(img, pose_results,
+                                                      sunglasses_img,
+                                                      left_eye_idx,
+                                                      right_eye_idx)
+                    # bug-eye effect
+                    if args.bugeye:
+                        if dataset_name in {
+                                'TopDownCocoDataset',
+                                'TopDownCocoWholeBodyDataset'
+                        }:
+                            left_eye_idx = 1
+                            right_eye_idx = 2
+                        elif dataset_name == 'AnimalPoseDataset':
+                            left_eye_idx = 0
+                            right_eye_idx = 1
+                        else:
+                            raise ValueError('Bug-eye effect does not '
+                                             f'support {dataset_name}')
+                        img = apply_bugeye_effect(img, pose_results,
+                                                  left_eye_idx, right_eye_idx)
+
+            # delay control
+            if args.display_delay > 0:
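+                # sleep off the remainder of the configured delay, measured
+                # from the frame's capture time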
+                t_sleep = args.display_delay * 0.001 - (time.time() - ts_input)
+                if t_sleep > 0:
+                    time.sleep(t_sleep)
+            t_delay = (time.time() - ts_input) * 1000
+
+            # show time information
+            t_info_display = stop_watch.report_strings()  # display fps
+            t_info_display.append(f'Inference FPS: {fps_inference:>5.1f}')
+            t_info_display.append(f'Delay: {t_delay:>3.0f}')
+            t_info_display.append(
+                f'Inference Delay: {t_delay_inference:>3.0f}')
+            t_info_str = ' | '.join(t_info_display + t_info)
+            cv2.putText(img, t_info_str, (20, 20), cv2.FONT_HERSHEY_DUPLEX,
+                        0.3, text_color, 1)
+            # collect system information
+            sys_info = [
+                f'RES: {img.shape[1]}x{img.shape[0]}',
+                f'Buffer: {frame_buffer.qsize()}/{frame_buffer.maxsize}'
+            ]
+            if psutil_proc is not None:
+                sys_info += [
+                    f'CPU: {psutil_proc.cpu_percent():.1f}%',
+                    f'MEM: {psutil_proc.memory_percent():.1f}%'
+                ]
+            sys_info_str = ' | '.join(sys_info)
+            cv2.putText(img, sys_info_str, (20, 40), cv2.FONT_HERSHEY_DUPLEX,
+                        0.3, text_color, 1)
+
+            # save the output video frame
+            if args.out_video_file is not None:
+                if vid_out is None:
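+                    # lazily create the writer once the output frame size
+                    # is known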
+                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                    fps = args.out_video_fps
+                    frame_size = (img.shape[1], img.shape[0])
+                    vid_out = cv2.VideoWriter(args.out_video_file, fourcc, fps,
+                                              frame_size)
+
+                vid_out.write(img)
+
+            # display
+            cv2.imshow('mmpose webcam demo', img)
+            keyboard_input = cv2.waitKey(1)
+            if keyboard_input in (27, ord('q'), ord('Q')):
+                break
+            elif keyboard_input == ord('s'):
+                args.sunglasses = not args.sunglasses
+            elif keyboard_input == ord('b'):
+                args.bugeye = not args.bugeye
+            elif keyboard_input == ord('v'):
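+                # cycle 0 (none) -> 1 (detection only) -> 2 (detection+pose)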
+                args.vis_mode = (args.vis_mode + 1) % 3
+
+    cv2.destroyAllWindows()
+    if vid_out is not None:
+        vid_out.release()
+    event_exit.set()
+
+
+def main():
+    global args
+    global frame_buffer
+    global input_queue, input_queue_mutex
+    global det_result_queue, det_result_queue_mutex
+    global pose_result_queue, pose_result_queue_mutex
+    global det_model, pose_model_list, pose_history_list
+    global event_exit, event_inference_done
+
+    args = parse_args()
+
+    assert has_mmdet, 'Please install mmdet to run the demo.'
+    assert args.det_config is not None
+    assert args.det_checkpoint is not None
+
+    # build detection model
+    det_model = init_detector(
+        args.det_config, args.det_checkpoint, device=args.device.lower())
+
+    # build pose models
+    pose_model_list = []
+    if args.enable_human_pose:
+        pose_model = init_pose_model(
+            args.human_pose_config,
+            args.human_pose_checkpoint,
+            device=args.device.lower())
+        model_info = {
+            'name': 'HumanPose',
+            'model': pose_model,
+            'cat_ids': args.human_det_ids,
+            'bbox_color': (148, 139, 255),
+        }
+        pose_model_list.append(model_info)
+    if args.enable_animal_pose:
+        pose_model = init_pose_model(
+            args.animal_pose_config,
+            args.animal_pose_checkpoint,
+            device=args.device.lower())
+        model_info = {
+            'name': 'AnimalPose',
+            'model': pose_model,
+            'cat_ids': args.animal_det_ids,
+            'bbox_color': 'cyan',
+        }
+        pose_model_list.append(model_info)
+
+    # store pose history for pose tracking
+    pose_history_list = []
+    for _ in range(len(pose_model_list)):
+        pose_history_list.append({'pose_results_last': [], 'next_id': 0})
+
+    # frame buffer
+    if args.buffer_size > 0:
+        buffer_size = args.buffer_size
+    else:
+        # infer buffer size from the display delay time
+        # assume that the maximum video fps is 30
+        buffer_size = round(30 * (1 + max(args.display_delay, 0) / 1000.))
+    frame_buffer = Queue(maxsize=buffer_size)
+
+    # queue of input frames
+    # element: (timestamp, frame)
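+    # maxlen=1 keeps only the latest element: stale frames are dropped
+    # rather than queued, which bounds end-to-end latency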
+    input_queue = deque(maxlen=1)
+    input_queue_mutex = Lock()
+
+    # queue of detection results
+    # element: tuple(timestamp, frame, time_info, det_results)
+    det_result_queue = deque(maxlen=1)
+    det_result_queue_mutex = Lock()
+
+    # queue of detection/pose results
+    # element: (timestamp, time_info, pose_results_list)
+    pose_result_queue = deque(maxlen=1)
+    pose_result_queue_mutex = Lock()
+
+    try:
+        event_exit = Event()
+        event_inference_done = Event()
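+        # det/pose threads are daemons and die with the main thread; the
+        # input thread is non-daemon and is joined explicitly below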
+        t_input = Thread(target=read_camera)
+        t_det = Thread(target=inference_detection, daemon=True)
+        t_pose = Thread(target=inference_pose, daemon=True)
+
+        t_input.start()
+        t_det.start()
+        t_pose.start()
+
+        # run display in the main thread
+        display()
+        # join the input thread (non-daemon)
+        t_input.join()
+
+    except KeyboardInterrupt:
+        pass
+
+
+if __name__ == '__main__':
+    main()