[6d389a]: mmaction/apis/inference.py

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import re
import warnings
from operator import itemgetter

import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmaction.core import OutputHook
from mmaction.datasets.pipelines import Compose
from mmaction.models import build_recognizer


def init_recognizer(config, checkpoint=None, device='cuda:0', **kwargs):
    """Initialize a recognizer from config file.

    Args:
        config (str | :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str | None, optional): Checkpoint path/url. If set to
            None, the model will not load any weights. Default: None.
        device (str | :obj:`torch.device`): The desired device of returned
            tensor. Default: 'cuda:0'.

    Returns:
        nn.Module: The constructed recognizer.
    """
    if 'use_frames' in kwargs:
        warnings.warn('The argument `use_frames` is deprecated PR #1191. '
                      'Now you can use models trained with frames or videos '
                      'arbitrarily. ')

    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    # pretrained model is unnecessary since we directly load checkpoint later
    config.model.backbone.pretrained = None
    model = build_recognizer(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        load_checkpoint(model, checkpoint, map_location='cpu')
    model.cfg = config
    model.to(device)
    model.eval()
    return model
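

# A minimal usage sketch (the config and checkpoint paths below are
# hypothetical placeholders, not files guaranteed to ship with the repo):
#
#   from mmaction.apis import init_recognizer
#   model = init_recognizer('configs/tsn_config.py',
#                           'checkpoints/tsn.pth',
#                           device='cuda:0')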


def inference_recognizer(model, video, outputs=None, as_tensor=True, **kwargs):
    """Inference a video with the recognizer.

    Args:
        model (nn.Module): The loaded recognizer.
        video (str | dict | ndarray): The video file path / url, the
            rawframes directory path, a results dictionary (the input of the
            pipeline) or a 4D array of shape T x H x W x 3 (the input video).
        outputs (list(str) | tuple(str) | str | None): Names of layers whose
            outputs need to be returned. Default: None.
        as_tensor (bool): Same as that in ``OutputHook``. Default: True.

    Returns:
        list[tuple(int, float)]: Top-5 recognition results as
            (label, score) tuples, sorted by score in descending order.
        dict[str, torch.Tensor | np.ndarray]: Output feature maps from the
            layers specified in ``outputs``, returned only when ``outputs``
            is given.
    """
    if 'use_frames' in kwargs:
        warnings.warn('The argument `use_frames` is deprecated PR #1191. '
                      'Now you can use models trained with frames or videos '
                      'arbitrarily. ')
    if 'label_path' in kwargs:
        warnings.warn('The argument `label_path` is deprecated PR #1191. '
                      'Now the label file is not needed in '
                      'inference_recognizer. ')

    input_flag = None
    if isinstance(video, dict):
        input_flag = 'dict'
    elif isinstance(video, np.ndarray):
        assert len(video.shape) == 4, 'The shape should be T x H x W x C'
        input_flag = 'array'
    elif isinstance(video, str) and video.startswith('http'):
        input_flag = 'video'
    elif isinstance(video, str) and osp.exists(video):
        if osp.isfile(video):
            input_flag = 'video'
        if osp.isdir(video):
            input_flag = 'rawframes'
    else:
        raise RuntimeError('The type of argument video is not supported: '
                           f'{type(video)}')

    if isinstance(outputs, str):
        outputs = (outputs, )
    assert outputs is None or isinstance(outputs, (tuple, list))

    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = cfg.data.test.pipeline
    # Alter data pipelines & prepare inputs
    if input_flag == 'dict':
        data = video
    if input_flag == 'array':
        modality_map = {2: 'Flow', 3: 'RGB'}
        modality = modality_map.get(video.shape[-1])
        data = dict(
            total_frames=video.shape[0],
            label=-1,
            start_index=0,
            array=video,
            modality=modality)
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='ArrayDecode')
    if input_flag == 'video':
        data = dict(filename=video, label=-1, start_index=0, modality='RGB')
        if 'Init' not in test_pipeline[0]['type']:
            test_pipeline = [dict(type='OpenCVInit')] + test_pipeline
        else:
            test_pipeline[0] = dict(type='OpenCVInit')
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='OpenCVDecode')
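        # Note: the rewrite above lets a pipeline written for another
        # decoding backend, e.g. [dict(type='DecordInit'), ...,
        # dict(type='DecordDecode'), ...], fall back to OpenCV, so a plain
        # file path or URL works regardless of the configured decoder.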
    if input_flag == 'rawframes':
        filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
        modality = cfg.data.test.get('modality', 'RGB')
        start_index = cfg.data.test.get('start_index', 1)

        # count the number of frames that match the format of `filename_tmpl`
        # RGB pattern example: img_{:05}.jpg -> ^img_\d+.jpg$
        # Flow pattern example: {}_{:05d}.jpg -> ^x_\d+.jpg$
        pattern = f'^{filename_tmpl}$'
        if modality == 'Flow':
            pattern = pattern.replace('{}', 'x')
        pattern = pattern.replace(
            pattern[pattern.find('{'):pattern.find('}') + 1], '\\d+')
        total_frames = len(
            list(
                filter(lambda x: re.match(pattern, x) is not None,
                       os.listdir(video))))
        data = dict(
            frame_dir=video,
            total_frames=total_frames,
            label=-1,
            start_index=start_index,
            filename_tmpl=filename_tmpl,
            modality=modality)
        if 'Init' in test_pipeline[0]['type']:
            test_pipeline = test_pipeline[1:]
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='RawFrameDecode')

    test_pipeline = Compose(test_pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    # forward the model
    with OutputHook(model, outputs=outputs, as_tensor=as_tensor) as h:
        with torch.no_grad():
            scores = model(return_loss=False, **data)[0]
        returned_features = h.layer_outputs if outputs else None

    num_classes = scores.shape[-1]
    score_tuples = tuple(zip(range(num_classes), scores))
    score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
    top5_label = score_sorted[:5]

    if outputs:
        return top5_label, returned_features
    return top5_label
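

# A minimal end-to-end sketch (the paths and the layer name 'backbone' are
# hypothetical placeholders chosen for illustration):
#
#   model = init_recognizer('configs/tsn_config.py', 'checkpoints/tsn.pth')
#   # Top-5 (class_index, score) tuples, highest score first:
#   results = inference_recognizer(model, 'demo/demo.mp4')
#   # Intermediate feature maps can be fetched via `outputs`:
#   results, features = inference_recognizer(
#       model, 'demo/demo.mp4', outputs=('backbone', ))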