--- a +++ b/demo/demo.py @@ -0,0 +1,207 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +import cv2 +import decord +import numpy as np +import torch +import webcolors +from mmcv import Config, DictAction + +from mmaction.apis import inference_recognizer, init_recognizer + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMAction2 demo') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file/url') + parser.add_argument('video', help='video file/url or rawframes directory') + parser.add_argument('label', help='label file') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") + parser.add_argument( + '--use-frames', + default=False, + action='store_true', + help='whether to use rawframes as input') + parser.add_argument( + '--device', type=str, default='cuda:0', help='CPU/CUDA device option') + parser.add_argument( + '--fps', + default=30, + type=int, + help='specify fps value of the output video when using rawframes to ' + 'generate file') + parser.add_argument( + '--font-scale', + default=0.5, + type=float, + help='font scale of the label in output video') + parser.add_argument( + '--font-color', + default='white', + help='font color of the label in output video') + parser.add_argument( + '--target-resolution', + nargs=2, + default=None, + type=int, + help='Target resolution (w, h) for resizing the frames when using a ' + 'video as input. If either dimension is set to -1, the frames are ' + 'resized by keeping the existing aspect ratio') + parser.add_argument( + '--resize-algorithm', + default='bicubic', + help='resize algorithm applied to generate video') + parser.add_argument('--out-filename', default=None, help='output filename') + args = parser.parse_args() + return args + + +def get_output(video_path, + out_filename, + label, + fps=30, + font_scale=0.5, + font_color='white', + target_resolution=None, + resize_algorithm='bicubic', + use_frames=False): + """Get demo output using ``moviepy``. + + This function will generate video file or gif file from raw video or + frames, by using ``moviepy``. For more information of some parameters, + you can refer to: https://github.com/Zulko/moviepy. + + Args: + video_path (str): The video file path or the rawframes directory path. + If ``use_frames`` is set to True, it should be rawframes directory + path. Otherwise, it should be video file path. + out_filename (str): Output filename for the generated file. + label (str): Predicted label of the generated file. + fps (int): Number of picture frames to read per second. Default: 30. + font_scale (float): Font scale of the label. Default: 0.5. + font_color (str): Font color of the label. Default: 'white'. + target_resolution (None | tuple[int | None]): Set to + (desired_width desired_height) to have resized frames. If either + dimension is None, the frames are resized by keeping the existing + aspect ratio. Default: None. + resize_algorithm (str): Support "bicubic", "bilinear", "neighbor", + "lanczos", etc. Default: 'bicubic'. For more information, + see https://ffmpeg.org/ffmpeg-scaler.html + use_frames: Determine Whether to use rawframes as input. Default:False. + """ + + if video_path.startswith(('http://', 'https://')): + raise NotImplementedError + + try: + from moviepy.editor import ImageSequenceClip + except ImportError: + raise ImportError('Please install moviepy to enable output file.') + + # Channel Order is BGR + if use_frames: + frame_list = sorted( + [osp.join(video_path, x) for x in os.listdir(video_path)]) + frames = [cv2.imread(x) for x in frame_list] + else: + video = decord.VideoReader(video_path) + frames = [x.asnumpy()[..., ::-1] for x in video] + + if target_resolution: + w, h = target_resolution + frame_h, frame_w, _ = frames[0].shape + if w == -1: + w = int(h / frame_h * frame_w) + if h == -1: + h = int(w / frame_w * frame_h) + frames = [cv2.resize(f, (w, h)) for f in frames] + + textsize = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, font_scale, + 1)[0] + textheight = textsize[1] + padding = 10 + location = (padding, padding + textheight) + + if isinstance(font_color, str): + font_color = webcolors.name_to_rgb(font_color)[::-1] + + frames = [np.array(frame) for frame in frames] + for frame in frames: + cv2.putText(frame, label, location, cv2.FONT_HERSHEY_DUPLEX, + font_scale, font_color, 1) + + # RGB order + frames = [x[..., ::-1] for x in frames] + video_clips = ImageSequenceClip(frames, fps=fps) + + out_type = osp.splitext(out_filename)[1][1:] + if out_type == 'gif': + video_clips.write_gif(out_filename) + else: + video_clips.write_videofile(out_filename, remove_temp=True) + + +def main(): + args = parse_args() + # assign the desired device. + device = torch.device(args.device) + + cfg = Config.fromfile(args.config) + cfg.merge_from_dict(args.cfg_options) + + # build the recognizer from a config file and checkpoint file/url + model = init_recognizer(cfg, args.checkpoint, device=device) + + # e.g. use ('backbone', ) to return backbone feature + output_layer_names = None + + # test a single video or rawframes of a single video + if output_layer_names: + results, returned_feature = inference_recognizer( + model, args.video, outputs=output_layer_names) + else: + results = inference_recognizer(model, args.video) + + labels = open(args.label).readlines() + labels = [x.strip() for x in labels] + results = [(labels[k[0]], k[1]) for k in results] + + print('The top-5 labels with corresponding scores are:') + for result in results: + print(f'{result[0]}: ', result[1]) + + if args.out_filename is not None: + + if args.target_resolution is not None: + if args.target_resolution[0] == -1: + assert isinstance(args.target_resolution[1], int) + assert args.target_resolution[1] > 0 + if args.target_resolution[1] == -1: + assert isinstance(args.target_resolution[0], int) + assert args.target_resolution[0] > 0 + args.target_resolution = tuple(args.target_resolution) + + get_output( + args.video, + args.out_filename, + results[0][0], + fps=args.fps, + font_scale=args.font_scale, + font_color=args.font_color, + target_resolution=args.target_resolution, + resize_algorithm=args.resize_algorithm, + use_frames=args.use_frames) + + +if __name__ == '__main__': + main()