--- /dev/null
+++ b/demo/long_video_demo.py
@@ -0,0 +1,265 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import random
+from collections import deque
+from operator import itemgetter
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv import Config, DictAction
+from mmcv.parallel import collate, scatter
+
+from mmaction.apis import init_recognizer
+from mmaction.datasets.pipelines import Compose
+
+FONTFACE = cv2.FONT_HERSHEY_COMPLEX_SMALL
+FONTSCALE = 1
+THICKNESS = 1
+LINETYPE = 1
+
+EXCLUDED_STEPS = [
+    'OpenCVInit', 'OpenCVDecode', 'DecordInit', 'DecordDecode', 'PyAVInit',
+    'PyAVDecode', 'RawFrameDecode'
+]
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='MMAction2 long video prediction demo')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file/url')
+    parser.add_argument('video_path', help='video file/url')
+    parser.add_argument('label', help='label file')
+    parser.add_argument('out_file', help='output result file (video/json)')
+    parser.add_argument(
+        '--input-step',
+        type=int,
+        default=1,
+        help='input step for sampling frames')
+    parser.add_argument(
+        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
+    parser.add_argument(
+        '--threshold',
+        type=float,
+        default=0.01,
+        help='recognition score threshold')
+    parser.add_argument(
+        '--stride',
+        type=float,
+        default=0,
+        help=('the prediction stride equals stride * sample_length '
+              '(sample_length is the size of the temporal window from '
+              'which frames are sampled, i.e. '
+              'clip_len x frame_interval); if set to 0, the '
+              'prediction stride is 1'))
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='override some settings in the used config; the key-value pair '
+        'in xxx=yyy format will be merged into the config file. For example, '
+        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+    parser.add_argument(
+        '--label-color',
+        nargs='+',
+        type=int,
+        default=(255, 255, 255),
+        help='font color (B, G, R) of the labels in the output video')
+    parser.add_argument(
+        '--msg-color',
+        nargs='+',
+        type=int,
+        default=(128, 128, 128),
+        help='font color (B, G, R) of the messages in the output video')
+    args = parser.parse_args()
+    return args
+
+
+def show_results_video(result_queue,
+                       text_info,
+                       thr,
+                       msg,
+                       frame,
+                       video_writer,
+                       label_color=(255, 255, 255),
+                       msg_color=(128, 128, 128)):
+    if len(result_queue) != 0:
+        text_info = {}
+        results = result_queue.popleft()
+        for i, result in enumerate(results):
+            selected_label, score = result
+            if score < thr:
+                break
+            location = (0, 40 + i * 20)
+            text = selected_label + ': ' + str(round(score, 2))
+            text_info[location] = text
+            cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
+                        label_color, THICKNESS, LINETYPE)
+    elif len(text_info):
+        for location, text in text_info.items():
+            cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
+                        label_color, THICKNESS, LINETYPE)
+    else:
+        cv2.putText(frame, msg, (0, 40), FONTFACE, FONTSCALE, msg_color,
+                    THICKNESS, LINETYPE)
+    video_writer.write(frame)
+    return text_info
+
+
+def get_results_json(result_queue, text_info, thr, msg, ind, out_json):
+    if len(result_queue) != 0:
+        text_info = {}
+        results = result_queue.popleft()
+        for i, result in enumerate(results):
+            selected_label, score = result
+            if score < thr:
+                break
+            text_info[i + 1] = selected_label + ': ' + str(round(score, 2))
+        out_json[ind] = text_info
+    elif len(text_info):
+        out_json[ind] = text_info
+    else:
+        out_json[ind] = msg
+    return text_info, out_json
+
+
+def show_results(model, data, label, args):
+    frame_queue = deque(maxlen=args.sample_length)
+    result_queue = deque(maxlen=1)
+
+    cap = cv2.VideoCapture(args.video_path)
+    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+
+    msg = 'Preparing action recognition ...'
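+    # How the rest of this function works: decoded frames (converted from
+    # BGR to RGB) are buffered in ``backup_frames``. The first
+    # ``sample_length`` frames fill ``frame_queue`` directly so that an early
+    # prediction can be shown; after that, one randomly chosen frame is
+    # appended to ``frame_queue`` every ``input_step`` frames. ``inference``
+    # runs whenever the queue holds a full window, and the top predictions
+    # are drawn on the output video or written to the JSON file.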
+    text_info = {}
+    out_json = {}
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    frame_size = (frame_width, frame_height)
+
+    ind = 0
+    video_writer = None if args.out_file.endswith('.json') \
+        else cv2.VideoWriter(args.out_file, fourcc, fps, frame_size)
+    prog_bar = mmcv.ProgressBar(num_frames)
+    backup_frames = []
+
+    while ind < num_frames:
+        ind += 1
+        prog_bar.update()
+        ret, frame = cap.read()
+        if frame is None:
+            # skip frames that fail to decode
+            continue
+        backup_frames.append(np.array(frame)[:, :, ::-1])
+        if ind == args.sample_length:
+            # fill the window at once so the first prediction appears early
+            frame_queue.extend(backup_frames)
+            backup_frames = []
+        elif ((len(backup_frames) == args.input_step
+               and ind > args.sample_length) or ind == num_frames):
+            # pick a frame from the backup
+            # when the backup is full or the last frame is reached
+            chosen_frame = random.choice(backup_frames)
+            backup_frames = []
+            frame_queue.append(chosen_frame)
+
+        ret, scores = inference(model, data, args, frame_queue)
+
+        if ret:
+            num_selected_labels = min(len(label), 5)
+            scores_tuples = tuple(zip(label, scores))
+            scores_sorted = sorted(
+                scores_tuples, key=itemgetter(1), reverse=True)
+            results = scores_sorted[:num_selected_labels]
+            result_queue.append(results)
+
+        if args.out_file.endswith('.json'):
+            text_info, out_json = get_results_json(result_queue, text_info,
+                                                   args.threshold, msg, ind,
+                                                   out_json)
+        else:
+            text_info = show_results_video(result_queue, text_info,
+                                           args.threshold, msg, frame,
+                                           video_writer, args.label_color,
+                                           args.msg_color)
+
+    cap.release()
+    cv2.destroyAllWindows()
+    if args.out_file.endswith('.json'):
+        with open(args.out_file, 'w') as js:
+            json.dump(out_json, js)
+
+
+def inference(model, data, args, frame_queue):
+    if len(frame_queue) != args.sample_length:
+        # do not run inference until enough frames are collected
+        return False, None
+
+    cur_windows = list(np.array(frame_queue))
+    if data['img_shape'] is None:
+        data['img_shape'] = frame_queue[0].shape[:2]
+
+    cur_data = data.copy()
+    cur_data['imgs'] = cur_windows
+    cur_data = args.test_pipeline(cur_data)
+    cur_data = collate([cur_data], samples_per_gpu=1)
+    if next(model.parameters()).is_cuda:
+        cur_data = scatter(cur_data, [args.device])[0]
+    with torch.no_grad():
+        scores = model(return_loss=False, **cur_data)[0]
+
+    if args.stride > 0:
+        pred_stride = int(args.sample_length * args.stride)
+        for _ in range(pred_stride):
+            frame_queue.popleft()
+
+    # for the case ``args.stride == 0``, the bounded deque automatically
+    # pops the oldest frame from the left when a new one is appended
+
+    return True, scores
+
+
+def main():
+    args = parse_args()
+
+    args.device = torch.device(args.device)
+
+    cfg = Config.fromfile(args.config)
+    cfg.merge_from_dict(args.cfg_options)
+
+    model = init_recognizer(cfg, args.checkpoint, device=args.device)
+    data = dict(img_shape=None, modality='RGB', label=-1)
+    with open(args.label, 'r') as f:
+        label = [line.strip() for line in f]
+
+    # prepare the test pipeline from the non-camera pipeline
+    cfg = model.cfg
+    sample_length = 0
+    pipeline = cfg.data.test.pipeline
+    pipeline_ = pipeline.copy()
+    for step in pipeline:
+        if 'SampleFrames' in step['type']:
+            sample_length = step['clip_len'] * step['num_clips']
+            data['num_clips'] = step['num_clips']
+            data['clip_len'] = step['clip_len']
+            pipeline_.remove(step)
+        if step['type'] in EXCLUDED_STEPS:
+            # remove steps that decode frames
+            pipeline_.remove(step)
+    test_pipeline = Compose(pipeline_)
+
+    assert sample_length > 0
+    args.sample_length = sample_length
+    args.test_pipeline = test_pipeline
+
+    show_results(model, data, label, args)
+
+
+if __name__ == '__main__':
+    main()