--- /dev/null
+++ b/demo/demo_video_structuralize.py
@@ -0,0 +1,797 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy as cp
+import os
+import os.path as osp
+import shutil
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv import DictAction
+from mmcv.runner import load_checkpoint
+
+from mmaction.apis import inference_recognizer
+from mmaction.datasets.pipelines import Compose
+from mmaction.models import build_detector, build_model, build_recognizer
+from mmaction.utils import import_module_error_func
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+    from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
+                             vis_pose_result)
+
+except (ImportError, ModuleNotFoundError):
+
+    @import_module_error_func('mmdet')
+    def inference_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmdet')
+    def init_detector(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def init_pose_model(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def inference_top_down_pose_model(*args, **kwargs):
+        pass
+
+    @import_module_error_func('mmpose')
+    def vis_pose_result(*args, **kwargs):
+        pass
+
+
+try:
+    import moviepy.editor as mpy
+except ImportError:
+    raise ImportError('Please install moviepy to enable writing the output '
+                      'video')
+
+FONTFACE = cv2.FONT_HERSHEY_DUPLEX
+FONTSCALE = 0.5
+FONTCOLOR = (255, 255, 255)  # BGR, white
+MSGCOLOR = (128, 128, 128)  # BGR, gray
+THICKNESS = 1
+LINETYPE = 1
+
+
+def hex2color(h):
+    """Convert a 6-digit hex string to a tuple of 3 int values (RGB)."""
+    return (int(h[:2], 16), int(h[2:4], 16), int(h[4:], 16))
+
+
+PLATEBLUE = '03045e-023e8a-0077b6-0096c7-00b4d8-48cae4'
+PLATEBLUE = PLATEBLUE.split('-')
+PLATEBLUE = [hex2color(h) for h in PLATEBLUE]
+PLATEGREEN = '004b23-006400-007200-008000-38b000-70e000'
+PLATEGREEN = PLATEGREEN.split('-')
+PLATEGREEN = [hex2color(h) for h in PLATEGREEN]
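+# e.g. hex2color('03045e') -> (3, 4, 94); PLATEBLUE / PLATEGREEN are thus
+# lists of six RGB tuples, which visualize() below reverses into BGR for
+# OpenCV drawing.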
+ """ + + assert max_num + 1 <= len(plate) + plate = [x[::-1] for x in plate] + frames_ = cp.deepcopy(frames) + nf, na = len(frames), len(annotations) + assert nf % na == 0 + nfpa = len(frames) // len(annotations) + anno = None + h, w, _ = frames[0].shape + scale_ratio = np.array([w, h, w, h]) + + # add pose results + if pose_results: + for i in range(nf): + frames_[i] = vis_pose_result(pose_model, frames_[i], + pose_results[i]) + + for i in range(na): + anno = annotations[i] + if anno is None: + continue + for j in range(nfpa): + ind = i * nfpa + j + frame = frames_[ind] + + # add action result for whole video + cv2.putText(frame, action_result, (10, 30), FONTFACE, FONTSCALE, + FONTCOLOR, THICKNESS, LINETYPE) + + # add spatio-temporal action detection results + for ann in anno: + box = ann[0] + label = ann[1] + if not len(label): + continue + score = ann[2] + box = (box * scale_ratio).astype(np.int64) + st, ed = tuple(box[:2]), tuple(box[2:]) + if not pose_results: + cv2.rectangle(frame, st, ed, plate[0], 2) + + for k, lb in enumerate(label): + if k >= max_num: + break + text = abbrev(lb) + text = ': '.join([text, str(score[k])]) + location = (0 + st[0], 18 + k * 18 + st[1]) + textsize = cv2.getTextSize(text, FONTFACE, FONTSCALE, + THICKNESS)[0] + textwidth = textsize[0] + diag0 = (location[0] + textwidth, location[1] - 14) + diag1 = (location[0], location[1] + 2) + cv2.rectangle(frame, diag0, diag1, plate[k + 1], -1) + cv2.putText(frame, text, location, FONTFACE, FONTSCALE, + FONTCOLOR, THICKNESS, LINETYPE) + + return frames_ + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMAction2 demo') + parser.add_argument( + '--rgb-stdet-config', + default=('configs/detection/ava/' + 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py'), + help='rgb-based spatio temporal detection config file path') + parser.add_argument( + '--rgb-stdet-checkpoint', + default=('https://download.openmmlab.com/mmaction/detection/ava/' + 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/' + 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb' + '_20201217-16378594.pth'), + help='rgb-based spatio temporal detection checkpoint file/url') + parser.add_argument( + '--skeleton-stdet-checkpoint', + default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/' + 'posec3d_ava.pth'), + help='skeleton-based spatio temporal detection checkpoint file/url') + parser.add_argument( + '--det-config', + default='demo/faster_rcnn_r50_fpn_2x_coco.py', + help='human detection config file path (from mmdet)') + parser.add_argument( + '--det-checkpoint', + default=('http://download.openmmlab.com/mmdetection/v2.0/' + 'faster_rcnn/faster_rcnn_r50_fpn_2x_coco/' + 'faster_rcnn_r50_fpn_2x_coco_' + 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'), + help='human detection checkpoint file/url') + parser.add_argument( + '--pose-config', + default='demo/hrnet_w32_coco_256x192.py', + help='human pose estimation config file path (from mmpose)') + parser.add_argument( + '--pose-checkpoint', + default=('https://download.openmmlab.com/mmpose/top_down/hrnet/' + 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'), + help='human pose estimation checkpoint file/url') + parser.add_argument( + '--skeleton-config', + default='configs/skeleton/posec3d/' + 'slowonly_r50_u48_240e_ntu120_xsub_keypoint.py', + help='skeleton-based action recognition config file path') + parser.add_argument( + '--skeleton-checkpoint', + default='https://download.openmmlab.com/mmaction/skeleton/posec3d/' + 'posec3d_k400.pth', + help='skeleton-based 
+
+
+def frame_extraction(video_path):
+    """Extract frames given video_path.
+
+    Args:
+        video_path (str): The path of the video file.
+
+    Returns:
+        tuple[list[str], list[np.ndarray]]: The paths of the extracted
+            frames and the frames themselves.
+    """
+    # Load the video, extract frames into ./tmp/video_name
+    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
+    os.makedirs(target_dir, exist_ok=True)
+    # Should be able to handle videos up to several hours
+    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
+    vid = cv2.VideoCapture(video_path)
+    frames = []
+    frame_paths = []
+    flag, frame = vid.read()
+    cnt = 0
+    while flag:
+        frames.append(frame)
+        frame_path = frame_tmpl.format(cnt + 1)
+        frame_paths.append(frame_path)
+        cv2.imwrite(frame_path, frame)
+        cnt += 1
+        flag, frame = vid.read()
+    return frame_paths, frames
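+
+
+# Note that frame_extraction both writes each frame to
+# ./tmp/<video_name>/img_%06d.jpg and keeps every decoded frame in memory,
+# so the demo is memory-bound on very long videos.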
+ """ + model = init_detector(args.det_config, args.det_checkpoint, args.device) + assert model.CLASSES[0] == 'person', ('We require you to use a detector ' + 'trained on COCO') + results = [] + print('Performing Human Detection for each frame') + prog_bar = mmcv.ProgressBar(len(frame_paths)) + for frame_path in frame_paths: + result = inference_detector(model, frame_path) + # We only keep human detections with score larger than det_score_thr + result = result[0][result[0][:, 4] >= args.det_score_thr] + results.append(result) + prog_bar.update() + + return results + + +def pose_inference(args, frame_paths, det_results): + model = init_pose_model(args.pose_config, args.pose_checkpoint, + args.device) + ret = [] + print('Performing Human Pose Estimation for each frame') + prog_bar = mmcv.ProgressBar(len(frame_paths)) + for f, d in zip(frame_paths, det_results): + # Align input format + d = [dict(bbox=x) for x in list(d)] + + pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0] + ret.append(pose) + prog_bar.update() + return ret + + +def load_label_map(file_path): + """Load Label Map. + + Args: + file_path (str): The file path of label map. + + Returns: + dict: The label map (int -> label name). + """ + lines = open(file_path).readlines() + lines = [x.strip().split(': ') for x in lines] + return {int(x[0]): x[1] for x in lines} + + +def abbrev(name): + """Get the abbreviation of label name: + + 'take (an object) from (a person)' -> 'take ... from ...' + """ + while name.find('(') != -1: + st, ed = name.find('('), name.find(')') + name = name[:st] + '...' + name[ed + 1:] + return name + + +def pack_result(human_detection, result, img_h, img_w): + """Short summary. + + Args: + human_detection (np.ndarray): Human detection result. + result (type): The predicted label of each human proposal. + img_h (int): The image height. + img_w (int): The image width. + + Returns: + tuple: Tuple of human proposal, label name and label score. 
+ """ + human_detection[:, 0::2] /= img_w + human_detection[:, 1::2] /= img_h + results = [] + if result is None: + return None + for prop, res in zip(human_detection, result): + res.sort(key=lambda x: -x[1]) + results.append( + (prop.data.cpu().numpy(), [x[0] for x in res], [x[1] + for x in res])) + return results + + +def expand_bbox(bbox, h, w, ratio=1.25): + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) // 2 + center_y = (y1 + y2) // 2 + width = x2 - x1 + height = y2 - y1 + + square_l = max(width, height) + new_width = new_height = square_l * ratio + + new_x1 = max(0, int(center_x - new_width / 2)) + new_x2 = min(int(center_x + new_width / 2), w) + new_y1 = max(0, int(center_y - new_height / 2)) + new_y2 = min(int(center_y + new_height / 2), h) + return (new_x1, new_y1, new_x2, new_y2) + + +def cal_iou(box1, box2): + xmin1, ymin1, xmax1, ymax1 = box1 + xmin2, ymin2, xmax2, ymax2 = box2 + + s1 = (xmax1 - xmin1) * (ymax1 - ymin1) + s2 = (xmax2 - xmin2) * (ymax2 - ymin2) + + xmin = max(xmin1, xmin2) + ymin = max(ymin1, ymin2) + xmax = min(xmax1, xmax2) + ymax = min(ymax1, ymax2) + + w = max(0, xmax - xmin) + h = max(0, ymax - ymin) + intersect = w * h + union = s1 + s2 - intersect + iou = intersect / union + + return iou + + +def skeleton_based_action_recognition(args, pose_results, num_frame, h, w): + fake_anno = dict( + frame_dict='', + label=-1, + img_shape=(h, w), + origin_shape=(h, w), + start_index=0, + modality='Pose', + total_frames=num_frame) + num_person = max([len(x) for x in pose_results]) + + num_keypoint = 17 + keypoint = np.zeros((num_person, num_frame, num_keypoint, 2), + dtype=np.float16) + keypoint_score = np.zeros((num_person, num_frame, num_keypoint), + dtype=np.float16) + for i, poses in enumerate(pose_results): + for j, pose in enumerate(poses): + pose = pose['keypoints'] + keypoint[j, i] = pose[:, :2] + keypoint_score[j, i] = pose[:, 2] + + fake_anno['keypoint'] = keypoint + fake_anno['keypoint_score'] = keypoint_score + + label_map = [x.strip() for x in open(args.label_map).readlines()] + num_class = len(label_map) + + skeleton_config = mmcv.Config.fromfile(args.skeleton_config) + skeleton_config.model.cls_head.num_classes = num_class # for K400 dataset + skeleton_pipeline = Compose(skeleton_config.test_pipeline) + skeleton_imgs = skeleton_pipeline(fake_anno)['imgs'][None] + skeleton_imgs = skeleton_imgs.to(args.device) + + # Build skeleton-based recognition model + skeleton_model = build_model(skeleton_config.model) + load_checkpoint( + skeleton_model, args.skeleton_checkpoint, map_location='cpu') + skeleton_model.to(args.device) + skeleton_model.eval() + + with torch.no_grad(): + output = skeleton_model(return_loss=False, imgs=skeleton_imgs) + + action_idx = np.argmax(output) + skeleton_action_result = label_map[ + action_idx] # skeleton-based action result for the whole video + return skeleton_action_result + + +def rgb_based_action_recognition(args): + rgb_config = mmcv.Config.fromfile(args.rgb_config) + rgb_config.model.backbone.pretrained = None + rgb_model = build_recognizer( + rgb_config.model, test_cfg=rgb_config.get('test_cfg')) + load_checkpoint(rgb_model, args.rgb_checkpoint, map_location='cpu') + rgb_model.cfg = rgb_config + rgb_model.to(args.device) + rgb_model.eval() + action_results = inference_recognizer(rgb_model, args.video, + args.label_map) + rgb_action_result = action_results[0][0] + return rgb_action_result + + +def skeleton_based_stdet(args, label_map, human_detections, pose_results, + num_frame, clip_len, frame_interval, h, w): + window_size 
+
+
+def skeleton_based_stdet(args, label_map, human_detections, pose_results,
+                         num_frame, clip_len, frame_interval, h, w):
+    window_size = clip_len * frame_interval
+    assert clip_len % 2 == 0, 'We would like to have an even clip_len'
+    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
+                           args.predict_stepsize)
+
+    skeleton_config = mmcv.Config.fromfile(args.skeleton_config)
+    num_class = max(label_map.keys()) + 1  # for AVA dataset (81)
+    skeleton_config.model.cls_head.num_classes = num_class
+    skeleton_pipeline = Compose(skeleton_config.test_pipeline)
+    skeleton_stdet_model = build_model(skeleton_config.model)
+    load_checkpoint(
+        skeleton_stdet_model,
+        args.skeleton_stdet_checkpoint,
+        map_location='cpu')
+    skeleton_stdet_model.to(args.device)
+    skeleton_stdet_model.eval()
+
+    skeleton_predictions = []
+
+    print('Performing SpatioTemporal Action Detection for each clip')
+    prog_bar = mmcv.ProgressBar(len(timestamps))
+    for timestamp in timestamps:
+        proposal = human_detections[timestamp - 1]
+        if proposal.shape[0] == 0:  # no people detected
+            skeleton_predictions.append(None)
+            continue
+
+        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
+        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
+        frame_inds = list(frame_inds - 1)
+        num_clip_frames = len(frame_inds)  # == clip_len
+
+        pose_result = [pose_results[ind] for ind in frame_inds]
+
+        skeleton_prediction = []
+        for i in range(proposal.shape[0]):  # num_person
+            skeleton_prediction.append([])
+
+            fake_anno = dict(
+                frame_dict='',
+                label=-1,
+                img_shape=(h, w),
+                origin_shape=(h, w),
+                start_index=0,
+                modality='Pose',
+                total_frames=num_clip_frames)
+            num_person = 1
+
+            num_keypoint = 17
+            keypoint = np.zeros(
+                (num_person, num_clip_frames, num_keypoint, 2))  # M T V 2
+            keypoint_score = np.zeros(
+                (num_person, num_clip_frames, num_keypoint))  # M T V
+
+            # pose matching: in every frame, pick the pose whose bbox
+            # overlaps the expanded person box most
+            person_bbox = proposal[i][:4]
+            area = expand_bbox(person_bbox, h, w)
+
+            for j, poses in enumerate(pose_result):  # num_clip_frames
+                max_iou = float('-inf')
+                index = -1
+                if len(poses) == 0:
+                    continue
+                for k, per_pose in enumerate(poses):
+                    iou = cal_iou(per_pose['bbox'][:4], area)
+                    if max_iou < iou:
+                        index = k
+                        max_iou = iou
+                keypoint[0, j] = poses[index]['keypoints'][:, :2]
+                keypoint_score[0, j] = poses[index]['keypoints'][:, 2]
+
+            fake_anno['keypoint'] = keypoint
+            fake_anno['keypoint_score'] = keypoint_score
+
+            skeleton_imgs = skeleton_pipeline(fake_anno)['imgs'][None]
+            skeleton_imgs = skeleton_imgs.to(args.device)
+
+            with torch.no_grad():
+                output = skeleton_stdet_model(
+                    return_loss=False, imgs=skeleton_imgs)
+                output = output[0]
+                for k in range(len(output)):  # 81 outputs (0 is background)
+                    if k not in label_map:
+                        continue
+                    if output[k] > args.action_score_thr:
+                        skeleton_prediction[i].append(
+                            (label_map[k], output[k]))
+
+        skeleton_predictions.append(skeleton_prediction)
+        prog_bar.update()
+
+    return timestamps, skeleton_predictions
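+
+
+# rgb_based_stdet below returns the same (timestamps, predictions) pair, but
+# instead of matching skeletons it feeds normalized RGB clips, with the human
+# detections as proposals, to an AVA-style spatio-temporal detector.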
+
+
+def rgb_based_stdet(args, frames, label_map, human_detections, w, h, new_w,
+                    new_h, w_ratio, h_ratio):
+
+    rgb_stdet_config = mmcv.Config.fromfile(args.rgb_stdet_config)
+    rgb_stdet_config.merge_from_dict(args.cfg_options)
+
+    val_pipeline = rgb_stdet_config.data.val.pipeline
+    sampler = [x for x in val_pipeline if x['type'] == 'SampleAVAFrames'][0]
+    clip_len, frame_interval = sampler['clip_len'], sampler['frame_interval']
+    assert clip_len % 2 == 0, 'We would like to have an even clip_len'
+
+    window_size = clip_len * frame_interval
+    num_frame = len(frames)
+    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
+                           args.predict_stepsize)
+
+    # Get img_norm_cfg
+    img_norm_cfg = rgb_stdet_config['img_norm_cfg']
+    if 'to_rgb' not in img_norm_cfg and 'to_bgr' in img_norm_cfg:
+        to_bgr = img_norm_cfg.pop('to_bgr')
+        img_norm_cfg['to_rgb'] = to_bgr
+    img_norm_cfg['mean'] = np.array(img_norm_cfg['mean'])
+    img_norm_cfg['std'] = np.array(img_norm_cfg['std'])
+
+    # Build the STDET model
+    try:
+        # Set action_thr to 0 so that every action class keeps the same
+        # number of bboxes; predictions are thresholded later with
+        # args.action_score_thr.
+        rgb_stdet_config['model']['test_cfg']['rcnn']['action_thr'] = .0
+    except KeyError:
+        pass
+
+    rgb_stdet_config.model.backbone.pretrained = None
+    rgb_stdet_model = build_detector(
+        rgb_stdet_config.model, test_cfg=rgb_stdet_config.get('test_cfg'))
+
+    load_checkpoint(
+        rgb_stdet_model, args.rgb_stdet_checkpoint, map_location='cpu')
+    rgb_stdet_model.to(args.device)
+    rgb_stdet_model.eval()
+
+    predictions = []
+
+    print('Performing SpatioTemporal Action Detection for each clip')
+    prog_bar = mmcv.ProgressBar(len(timestamps))
+    for timestamp in timestamps:
+        proposal = human_detections[timestamp - 1]
+
+        if proposal.shape[0] == 0:
+            predictions.append(None)
+            continue
+
+        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
+        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
+        frame_inds = list(frame_inds - 1)
+
+        imgs = [frames[ind].astype(np.float32) for ind in frame_inds]
+        _ = [mmcv.imnormalize_(img, **img_norm_cfg) for img in imgs]
+        # THWC -> CTHW -> 1CTHW
+        input_array = np.stack(imgs).transpose((3, 0, 1, 2))[np.newaxis]
+        input_tensor = torch.from_numpy(input_array).to(args.device)
+
+        with torch.no_grad():
+            result = rgb_stdet_model(
+                return_loss=False,
+                img=[input_tensor],
+                img_metas=[[dict(img_shape=(new_h, new_w))]],
+                proposals=[[proposal]])
+            result = result[0]
+            prediction = []
+            # N proposals
+            for i in range(proposal.shape[0]):
+                prediction.append([])
+
+            # Keep predictions above the action score threshold
+            for i in range(len(result)):  # 80 classes (label ids are 1-based)
+                if i + 1 not in label_map:
+                    continue
+                for j in range(proposal.shape[0]):
+                    if result[i][j, 4] > args.action_score_thr:
+                        prediction[j].append((label_map[i + 1], result[i][j,
+                                                                          4]))
+            predictions.append(prediction)
+        prog_bar.update()
+
+    return timestamps, predictions
+
+
+def main():
+    args = parse_args()
+
+    frame_paths, original_frames = frame_extraction(args.video)
+    num_frame = len(frame_paths)
+    h, w, _ = original_frames[0].shape
+
+    # Get human detection results and pose results
+    human_detections = detection_inference(args, frame_paths)
+    pose_results = None
+    if args.use_skeleton_recog or args.use_skeleton_stdet:
+        pose_results = pose_inference(args, frame_paths, human_detections)
+
+    # Resize frames to short side 256
+    new_w, new_h = mmcv.rescale_size((w, h), (256, np.Inf))
+    frames = [mmcv.imresize(img, (new_w, new_h)) for img in original_frames]
+    w_ratio, h_ratio = new_w / w, new_h / h
+
+    # Load the spatio-temporal detection label_map
+    stdet_label_map = load_label_map(args.label_map_stdet)
+    rgb_stdet_config = mmcv.Config.fromfile(args.rgb_stdet_config)
+    rgb_stdet_config.merge_from_dict(args.cfg_options)
+    try:
+        if rgb_stdet_config['data']['train']['custom_classes'] is not None:
+            stdet_label_map = {
+                idx + 1: stdet_label_map[cls]
+                for idx, cls in enumerate(rgb_stdet_config['data']['train']
+                                          ['custom_classes'])
+            }
+    except KeyError:
+        pass
+
+    action_result = None
+    if args.use_skeleton_recog:
+        print('Use skeleton-based recognition')
+        action_result = skeleton_based_action_recognition(
+            args, pose_results, num_frame, h, w)
+    else:
+        print('Use rgb-based recognition')
+        action_result = rgb_based_action_recognition(args)
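+
+    # Below, the detections are rescaled by (w_ratio, h_ratio) so that the
+    # proposals line up with the short-side-256 frames used by the stdet
+    # models, and are converted to tensors on the target device. The skeleton
+    # branch runs stdet first because its pose matching works in the original
+    # resolution.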
+
+    stdet_preds = None
+    if args.use_skeleton_stdet:
+        print('Use skeleton-based SpatioTemporal Action Detection')
+        clip_len, frame_interval = 30, 1
+        timestamps, stdet_preds = skeleton_based_stdet(args, stdet_label_map,
+                                                       human_detections,
+                                                       pose_results,
+                                                       num_frame, clip_len,
+                                                       frame_interval, h, w)
+        for i in range(len(human_detections)):
+            det = human_detections[i]
+            det[:, 0:4:2] *= w_ratio
+            det[:, 1:4:2] *= h_ratio
+            human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
+
+    else:
+        print('Use rgb-based SpatioTemporal Action Detection')
+        for i in range(len(human_detections)):
+            det = human_detections[i]
+            det[:, 0:4:2] *= w_ratio
+            det[:, 1:4:2] *= h_ratio
+            human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
+        timestamps, stdet_preds = rgb_based_stdet(args, frames,
+                                                  stdet_label_map,
+                                                  human_detections, w, h,
+                                                  new_w, new_h, w_ratio,
+                                                  h_ratio)
+
+    stdet_results = []
+    for timestamp, prediction in zip(timestamps, stdet_preds):
+        human_detection = human_detections[timestamp - 1]
+        stdet_results.append(
+            pack_result(human_detection, prediction, new_h, new_w))
+
+    def dense_timestamps(timestamps, n):
+        """Densify the timestamps to n times as many, evenly spaced."""
+        old_frame_interval = (timestamps[1] - timestamps[0])
+        start = timestamps[0] - old_frame_interval / n * (n - 1) / 2
+        new_frame_inds = np.arange(
+            len(timestamps) * n) * old_frame_interval / n + start
+        return new_frame_inds.astype(np.int64)
+
+    dense_n = int(args.predict_stepsize / args.output_stepsize)
+    output_timestamps = dense_timestamps(timestamps, dense_n)
+    frames = [
+        cv2.imread(frame_paths[timestamp - 1])
+        for timestamp in output_timestamps
+    ]
+
+    print('Performing visualization')
+    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
+                                 args.device)
+
+    if args.use_skeleton_recog or args.use_skeleton_stdet:
+        pose_results = [
+            pose_results[timestamp - 1] for timestamp in output_timestamps
+        ]
+
+    vis_frames = visualize(frames, stdet_results, pose_results, action_result,
+                           pose_model)
+    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames],
+                                fps=args.output_fps)
+    vid.write_videofile(args.out_filename)
+
+    tmp_frame_dir = osp.dirname(frame_paths[0])
+    shutil.rmtree(tmp_frame_dir)
+
+
+if __name__ == '__main__':
+    main()