|
demo/webcam_demo.py

# Copyright (c) OpenMMLab. All rights reserved.
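"""MMAction2 webcam demo.

Reads frames from a webcam in a display thread and runs a video recognizer
over sliding windows of those frames in an inference thread, overlaying the
top predicted action labels on the live view.

Example invocation (the config/checkpoint/label paths below are illustrative;
substitute your own files):

    python demo/webcam_demo.py \
        configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py \
        checkpoints/tsn_r50.pth demo/label_map_k400.txt \
        --average-size 5 --threshold 0.2 --device cpu
"""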
|
|
import argparse
import time
from collections import deque
from operator import itemgetter
from threading import Thread

import cv2
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.parallel import collate, scatter

from mmaction.apis import init_recognizer
from mmaction.datasets.pipelines import Compose
|
|
FONTFACE = cv2.FONT_HERSHEY_COMPLEX_SMALL
FONTSCALE = 1
FONTCOLOR = (255, 255, 255)  # BGR, white
MSGCOLOR = (128, 128, 128)  # BGR, gray
THICKNESS = 1
LINETYPE = 1

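# Pipeline steps that read or decode frames from files. They are stripped
# from the test pipeline, since frames arrive directly from the camera.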
EXCLUDED_STEPS = [
    'OpenCVInit', 'OpenCVDecode', 'DecordInit', 'DecordDecode', 'PyAVInit',
    'PyAVDecode', 'RawFrameDecode'
]


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 webcam demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('label', help='label file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--camera-id', type=int, default=0, help='camera device id')
    parser.add_argument(
        '--threshold',
        type=float,
        default=0.01,
        help='recognition score threshold')
    parser.add_argument(
        '--average-size',
        type=int,
        default=1,
        help='number of latest clips to be averaged for prediction')
    parser.add_argument(
        '--drawing-fps',
        type=int,
        default=20,
        help='upper bound FPS of the output drawing')
    parser.add_argument(
        '--inference-fps',
        type=int,
        default=4,
        help='upper bound FPS of model inference')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    args = parser.parse_args()
    assert args.drawing_fps >= 0 and args.inference_fps >= 0, \
        'upper bound FPS values of drawing and inference should be ' \
        'positive numbers, or zero for no limit'
    return args
|
|


def show_results():
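    """Display thread: read webcam frames, push them to `frame_queue`, and
    overlay the most recent recognition results until the user quits."""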
|
|
    print('Press "Esc", "q" or "Q" to exit')

    text_info = {}
    cur_time = time.time()
    while True:
        msg = 'Waiting for action ...'
        _, frame = camera.read()
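        # OpenCV delivers BGR frames; reverse the channel axis so the model
        # receives RGB input.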
        frame_queue.append(np.array(frame[:, :, ::-1]))

        if len(result_queue) != 0:
            text_info = {}
            results = result_queue.popleft()
            for i, result in enumerate(results):
                selected_label, score = result
                if score < threshold:
                    break
                location = (0, 40 + i * 20)
                text = selected_label + ': ' + str(round(score, 2))
                text_info[location] = text
                cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                            FONTCOLOR, THICKNESS, LINETYPE)

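        # No fresh result this frame: re-draw the last labels so the overlay
        # does not flicker between inference updates.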
        elif len(text_info) != 0:
            for location, text in text_info.items():
                cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                            FONTCOLOR, THICKNESS, LINETYPE)

        else:
            cv2.putText(frame, msg, (0, 40), FONTFACE, FONTSCALE, MSGCOLOR,
                        THICKNESS, LINETYPE)

        cv2.imshow('camera', frame)
        ch = cv2.waitKey(1)

        if ch == 27 or ch == ord('q') or ch == ord('Q'):
            break

        if drawing_fps > 0:
            # limit the actual drawing FPS to at most `drawing_fps`
            sleep_time = 1 / drawing_fps - (time.time() - cur_time)
            if sleep_time > 0:
                time.sleep(sleep_time)
            cur_time = time.time()

|
|
|
def inference():
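    """Inference thread: assemble sampling windows from `frame_queue`, run the
    recognizer, average scores over the latest `average_size` clips, and push
    the top predictions to `result_queue`."""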
|
|
    score_cache = deque()
    scores_sum = 0
    cur_time = time.time()
    while True:
        cur_windows = []

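        # busy-wait until the frame queue holds one full sampling window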
        while len(cur_windows) == 0:
            if len(frame_queue) == sample_length:
                cur_windows = list(np.array(frame_queue))
                if data['img_shape'] is None:
                    data['img_shape'] = frame_queue.popleft().shape[:2]

        cur_data = data.copy()
        cur_data['imgs'] = cur_windows
        cur_data = test_pipeline(cur_data)
        cur_data = collate([cur_data], samples_per_gpu=1)
        if next(model.parameters()).is_cuda:
            cur_data = scatter(cur_data, [device])[0]
|
|

        with torch.no_grad():
            scores = model(return_loss=False, **cur_data)[0]

        score_cache.append(scores)
        scores_sum += scores

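        # once `average_size` clips are cached, report the running mean and
        # slide the window forward by dropping the oldest clip's scores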
        if len(score_cache) == average_size:
            scores_avg = scores_sum / average_size
            num_selected_labels = min(len(label), 5)

            scores_tuples = tuple(zip(label, scores_avg))
            scores_sorted = sorted(
                scores_tuples, key=itemgetter(1), reverse=True)
            results = scores_sorted[:num_selected_labels]

            result_queue.append(results)
            scores_sum -= score_cache.popleft()

        if inference_fps > 0:
            # limit the actual inference FPS to at most `inference_fps`
            sleep_time = 1 / inference_fps - (time.time() - cur_time)
            if sleep_time > 0:
                time.sleep(sleep_time)
            cur_time = time.time()

    camera.release()
    cv2.destroyAllWindows()
|
|


def main():
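    """Parse arguments, build the recognizer and test pipeline, then launch
    the display and inference threads."""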
|
|
    global frame_queue, camera, frame, results, threshold, sample_length, \
        data, test_pipeline, model, device, average_size, label, \
        result_queue, drawing_fps, inference_fps

    args = parse_args()
    average_size = args.average_size
    threshold = args.threshold
    drawing_fps = args.drawing_fps
    inference_fps = args.inference_fps

    device = torch.device(args.device)

    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)

    model = init_recognizer(cfg, args.checkpoint, device=device)
    camera = cv2.VideoCapture(args.camera_id)
    data = dict(img_shape=None, modality='RGB', label=-1)

    with open(args.label, 'r') as f:
        label = [line.strip() for line in f]
|
|

    # prepare the test pipeline from the non-camera (file-based) pipeline
    cfg = model.cfg
    sample_length = 0
    pipeline = cfg.data.test.pipeline
    pipeline_ = pipeline.copy()
    for step in pipeline:
        if 'SampleFrames' in step['type']:
            sample_length = step['clip_len'] * step['num_clips']
            data['num_clips'] = step['num_clips']
            data['clip_len'] = step['clip_len']
            pipeline_.remove(step)
        if step['type'] in EXCLUDED_STEPS:
            # remove steps that decode frames from files
            pipeline_.remove(step)
    test_pipeline = Compose(pipeline_)

    assert sample_length > 0

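    # `show_results` fills `frame_queue`; `inference` consumes it and feeds
    # `result_queue` back. Daemon threads let Ctrl-C terminate the demo.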
    try:
        frame_queue = deque(maxlen=sample_length)
        result_queue = deque(maxlen=1)
        pw = Thread(target=show_results, args=(), daemon=True)
        pr = Thread(target=inference, args=(), daemon=True)
        pw.start()
        pr.start()
        pw.join()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()