Diff of /demo/demo.py [000000] .. [6d389a]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

import cv2
import decord
import numpy as np
import torch
import webcolors
from mmcv import Config, DictAction

from mmaction.apis import inference_recognizer, init_recognizer
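
# ``moviepy`` is a further dependency for writing the demo output; it is
# imported lazily inside ``get_output`` and is only needed when
# ``--out-filename`` is given.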


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file/url')
    parser.add_argument('video', help='video file/url or rawframes directory')
    parser.add_argument('label', help='label file')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pairs '
        'in xxx=yyy format will be merged into the config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--use-frames',
        default=False,
        action='store_true',
        help='whether to use rawframes as input')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--fps',
        default=30,
        type=int,
        help='specify the fps of the output video when using rawframes as '
        'input')
    parser.add_argument(
        '--font-scale',
        default=0.5,
        type=float,
        help='font scale of the label in the output video')
    parser.add_argument(
        '--font-color',
        default='white',
        help='font color of the label in the output video')
    parser.add_argument(
        '--target-resolution',
        nargs=2,
        default=None,
        type=int,
        help='Target resolution (w, h) for resizing the frames when using a '
        'video as input. If either dimension is set to -1, the frames are '
        'resized by keeping the existing aspect ratio')
    parser.add_argument(
        '--resize-algorithm',
        default='bicubic',
        help='resize algorithm applied to the generated video')
    parser.add_argument('--out-filename', default=None, help='output filename')
    args = parser.parse_args()
    return args
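

# A typical invocation of this demo (the paths below are illustrative only):
#   python demo/demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${VIDEO_FILE} \
#       ${LABEL_FILE} --out-filename demo/demo_out.mp4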


def get_output(video_path,
               out_filename,
               label,
               fps=30,
               font_scale=0.5,
               font_color='white',
               target_resolution=None,
               resize_algorithm='bicubic',
               use_frames=False):
    """Get demo output using ``moviepy``.

    This function will generate a video file or a gif file from raw video or
    frames, by using ``moviepy``. For more information on some parameters,
    you can refer to: https://github.com/Zulko/moviepy.

    Args:
        video_path (str): The video file path or the rawframes directory path.
            If ``use_frames`` is set to True, it should be the rawframes
            directory path. Otherwise, it should be the video file path.
        out_filename (str): Output filename for the generated file.
        label (str): Predicted label of the generated file.
        fps (int): Number of picture frames to read per second. Default: 30.
        font_scale (float): Font scale of the label. Default: 0.5.
        font_color (str): Font color of the label. Default: 'white'.
        target_resolution (None | tuple[int]): Set to (desired_width,
            desired_height) to have resized frames. If either dimension is set
            to -1, the frames are resized by keeping the existing aspect
            ratio. Default: None.
        resize_algorithm (str): Support "bicubic", "bilinear", "neighbor",
            "lanczos", etc. Default: 'bicubic'. For more information,
            see https://ffmpeg.org/ffmpeg-scaler.html.
        use_frames (bool): Whether to use rawframes as input. Default: False.
    """

    if video_path.startswith(('http://', 'https://')):
        raise NotImplementedError

    try:
        from moviepy.editor import ImageSequenceClip
    except ImportError:
        raise ImportError('Please install moviepy to enable output file.')

    # Frames are kept in BGR channel order until the final conversion for
    # moviepy: cv2.imread already returns BGR, while decord returns RGB, so
    # the decord frames are flipped here.
    if use_frames:
        frame_list = sorted(
            [osp.join(video_path, x) for x in os.listdir(video_path)])
        frames = [cv2.imread(x) for x in frame_list]
    else:
        video = decord.VideoReader(video_path)
        frames = [x.asnumpy()[..., ::-1] for x in video]

    # If one dimension of the target resolution is -1, infer it from the other
    # dimension so that the original aspect ratio is preserved.
    if target_resolution:
        w, h = target_resolution
        frame_h, frame_w, _ = frames[0].shape
        if w == -1:
            w = int(h / frame_h * frame_w)
        if h == -1:
            h = int(w / frame_w * frame_h)
        frames = [cv2.resize(f, (w, h)) for f in frames]
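
    # Note: ``resize_algorithm`` is accepted for interface compatibility but is
    # not applied in this cv2-based path; ``cv2.resize`` above uses its default
    # bilinear interpolation. A possible mapping (a sketch only, not part of
    # the original code) would be:
    #     interp = {'bilinear': cv2.INTER_LINEAR, 'bicubic': cv2.INTER_CUBIC,
    #               'neighbor': cv2.INTER_NEAREST,
    #               'lanczos': cv2.INTER_LANCZOS4}.get(resize_algorithm,
    #                                                  cv2.INTER_LINEAR)
    #     frames = [cv2.resize(f, (w, h), interpolation=interp)
    #               for f in frames]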

    textsize = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, font_scale,
                               1)[0]
    textheight = textsize[1]
    padding = 10
    location = (padding, padding + textheight)

    # ``webcolors`` gives (R, G, B); reverse it to BGR to match the frames.
    if isinstance(font_color, str):
        font_color = webcolors.name_to_rgb(font_color)[::-1]

    frames = [np.array(frame) for frame in frames]
    for frame in frames:
        cv2.putText(frame, label, location, cv2.FONT_HERSHEY_DUPLEX,
                    font_scale, font_color, 1)

    # moviepy expects RGB frames, so flip the channel order back.
    frames = [x[..., ::-1] for x in frames]
    video_clips = ImageSequenceClip(frames, fps=fps)

    # The output format is decided by the extension of ``out_filename``.
    out_type = osp.splitext(out_filename)[1][1:]
    if out_type == 'gif':
        video_clips.write_gif(out_filename)
    else:
        video_clips.write_videofile(out_filename, remove_temp=True)
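

# ``get_output`` can also be exercised directly; the path and label below are
# illustrative only:
#   get_output('demo/demo.mp4', 'demo/demo_out.mp4', 'arm wrestling',
#              target_resolution=(320, -1))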


def main():
    args = parse_args()
    # assign the desired device.
    device = torch.device(args.device)

    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)

    # build the recognizer from a config file and checkpoint file/url
    model = init_recognizer(cfg, args.checkpoint, device=device)

    # e.g. use ('backbone', ) to return backbone feature
    output_layer_names = None

    # test a single video or rawframes of a single video
    if output_layer_names:
        results, returned_feature = inference_recognizer(
            model, args.video, outputs=output_layer_names)
    else:
        results = inference_recognizer(model, args.video)

    labels = open(args.label).readlines()
    labels = [x.strip() for x in labels]
    # ``results`` holds (label index, score) pairs; map indices to label names
    results = [(labels[k[0]], k[1]) for k in results]

    print('The top-5 labels with corresponding scores are:')
    for result in results:
        print(f'{result[0]}: ', result[1])

    if args.out_filename is not None:

        if args.target_resolution is not None:
            # if one dimension is -1, the other must be a positive integer
            if args.target_resolution[0] == -1:
                assert isinstance(args.target_resolution[1], int)
                assert args.target_resolution[1] > 0
            if args.target_resolution[1] == -1:
                assert isinstance(args.target_resolution[0], int)
                assert args.target_resolution[0] > 0
            args.target_resolution = tuple(args.target_resolution)

        get_output(
            args.video,
            args.out_filename,
            results[0][0],
            fps=args.fps,
            font_scale=args.font_scale,
            font_color=args.font_color,
            target_resolution=args.target_resolution,
            resize_algorithm=args.resize_algorithm,
            use_frames=args.use_frames)


if __name__ == '__main__':
    main()
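
# Rawframes can be used as input as well; for instance (illustrative paths):
#   python demo/demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${RAWFRAMES_DIR} \
#       ${LABEL_FILE} --use-frames --fps 24 --out-filename demo/demo_out.gif
# A ``.gif`` extension in --out-filename makes ``get_output`` write a GIF
# instead of a video file.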