Diff of /demo/demo_gradcam.py [000000] .. [6d389a]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.parallel import collate, scatter

from mmaction.apis import init_recognizer
from mmaction.datasets.pipelines import Compose
from mmaction.utils import GradCAM


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 GradCAM demo')

    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file/url')
    parser.add_argument('video', help='video file/url or rawframes directory')
    parser.add_argument(
        '--use-frames',
        default=False,
        action='store_true',
        help='whether to use rawframes as input')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--target-layer-name',
        type=str,
        default='backbone/layer4/1/relu',
        help='GradCAM target layer name')
    parser.add_argument('--out-filename', default=None, help='output filename')
    parser.add_argument(
        '--fps', default=5, type=int, help='fps of the output video/gif')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config; key-value pairs '
        'in xxx=yyy format will be merged into the config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--target-resolution',
        nargs=2,
        default=None,
        type=int,
        help='Target resolution (w, h) for resizing the frames when using a '
        'video as input. If either dimension is set to -1, the frames are '
        'resized while keeping the existing aspect ratio')
    parser.add_argument(
        '--resize-algorithm',
        default='bilinear',
        help='resize algorithm applied to generate the video & gif')

    args = parser.parse_args()
    return args
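

# A note on '--target-layer-name': the '/'-separated value is resolved
# attribute by attribute on the recognizer by the GradCAM utility. A minimal
# sketch of the mapping, assuming a ResNet-style backbone:
#
#   'backbone/layer4/1/relu'  ->  model.backbone.layer4[1].relu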


def build_inputs(model, video_path, use_frames=False):
    """Build inputs for GradCAM.

    Note that building inputs for GradCAM is exactly the same as building
    inputs for the Recognizer test stage. Code is adapted from
    `inference_recognizer`.

    Args:
        model (nn.Module): Recognizer model.
        video_path (str): Video file/url or rawframes directory.
        use_frames (bool): Whether to use rawframes as input.

    Returns:
        dict: Both GradCAM inputs and Recognizer test stage inputs,
            including two keys, ``imgs`` and ``label``.
    """
    if not (osp.exists(video_path) or video_path.startswith('http')):
        raise RuntimeError(f"'{video_path}' is missing")

    if osp.isfile(video_path) and use_frames:
        raise RuntimeError(
            f"'{video_path}' is a video file, not a rawframe directory")
    if osp.isdir(video_path) and not use_frames:
        raise RuntimeError(
            f"'{video_path}' is a rawframe directory, not a video file")

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    test_pipeline = cfg.data.test.pipeline
    test_pipeline = Compose(test_pipeline)
    # prepare data
    if use_frames:
        filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
        modality = cfg.data.test.get('modality', 'RGB')
        start_index = cfg.data.test.get('start_index', 0)
        data = dict(
            frame_dir=video_path,
            total_frames=len(os.listdir(video_path)),
            label=-1,
            start_index=start_index,
            filename_tmpl=filename_tmpl,
            modality=modality)
    else:
        start_index = cfg.data.test.get('start_index', 0)
        data = dict(
            filename=video_path,
            label=-1,
            start_index=start_index,
            modality='RGB')
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    return data
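

# A minimal sketch of calling `build_inputs` directly; the config, checkpoint
# and video paths below are placeholders, not files shipped with this demo:
#
#   model = init_recognizer('some_config.py', 'some_checkpoint.pth',
#                           device='cuda:0')
#   data = build_inputs(model, 'demo/demo.mp4')
#   data['imgs']   # batched input tensor; shape depends on the test pipeline
#   data['label']  # dummy label (-1), unused by GradCAM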


def _resize_frames(frame_list,
                   scale,
                   keep_ratio=True,
                   interpolation='bilinear'):
    """Resize frames according to the given scale.

    Code is adapted from the `Resize` class in
    `mmaction/datasets/pipelines/augmentations.py`.

    Args:
        frame_list (list[np.ndarray]): Frames to be resized.
        scale (tuple[int]): If keep_ratio is True, it serves as a scaling
            factor or maximum size: the image will be rescaled as large
            as possible within the scale. Otherwise, it serves as the (w, h)
            of the output size.
        keep_ratio (bool): If set to True, images will be resized without
            changing the aspect ratio. Otherwise, they will be resized to a
            given size. Default: True.
        interpolation (str): Algorithm used for interpolation:
            "nearest" | "bilinear". Default: "bilinear".

    Returns:
        list[np.ndarray]: The resized frames.
    """
    if scale is None or (scale[0] == -1 and scale[1] == -1):
        return frame_list
    scale = tuple(scale)
    max_long_edge = max(scale)
    max_short_edge = min(scale)
    if max_short_edge == -1:
        scale = (np.inf, max_long_edge)

    img_h, img_w, _ = frame_list[0].shape

    if keep_ratio:
        new_w, new_h = mmcv.rescale_size((img_w, img_h), scale)
    else:
        new_w, new_h = scale

    frame_list = [
        mmcv.imresize(img, (new_w, new_h), interpolation=interpolation)
        for img in frame_list
    ]

    return frame_list
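

# For example (a sketch assuming 1280x720 input frames): scale=(-1, 480) is
# rewritten to (np.inf, 480) above, so `mmcv.rescale_size` keeps the aspect
# ratio and yields about 853x480, while keep_ratio=False with
# scale=(320, 240) forces every frame to exactly 320x240.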


def main():
    args = parse_args()

    # assign the desired device
    device = torch.device(args.device)

    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)

    # build the recognizer from a config file and checkpoint file/url
    model = init_recognizer(cfg, args.checkpoint, device=device)

    inputs = build_inputs(model, args.video, use_frames=args.use_frames)
    gradcam = GradCAM(model, args.target_layer_name)
    results = gradcam(inputs)

    if args.out_filename is not None:
        try:
            from moviepy.editor import ImageSequenceClip
        except ImportError:
            raise ImportError(
                'Please install moviepy to enable writing the output file.')

        # frames_batches shape [B, T, H, W, 3], in RGB order
        frames_batches = (results[0] * 255.).numpy().astype(np.uint8)
        frames = frames_batches.reshape(-1, *frames_batches.shape[-3:])

        frame_list = list(frames)
        frame_list = _resize_frames(
            frame_list,
            args.target_resolution,
            interpolation=args.resize_algorithm)

        video_clips = ImageSequenceClip(frame_list, fps=args.fps)
        out_type = osp.splitext(args.out_filename)[1][1:]
        if out_type == 'gif':
            video_clips.write_gif(args.out_filename)
        else:
            video_clips.write_videofile(args.out_filename, remove_temp=True)


if __name__ == '__main__':
    main()
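
# Example invocation (a sketch; the config and checkpoint paths are
# placeholders for files from the MMAction2 model zoo):
#
#   python demo/demo_gradcam.py \
#       configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py \
#       checkpoints/tsn_r50.pth demo/demo.mp4 \
#       --out-filename demo/demo_gradcam.gif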