Diff of /demo/webcam_demo.py [000000] .. [6d389a]


b/demo/webcam_demo.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time
from collections import deque
from operator import itemgetter
from threading import Thread

import cv2
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.parallel import collate, scatter

from mmaction.apis import init_recognizer
from mmaction.datasets.pipelines import Compose

FONTFACE = cv2.FONT_HERSHEY_COMPLEX_SMALL
FONTSCALE = 1
FONTCOLOR = (255, 255, 255)  # BGR, white
MSGCOLOR = (128, 128, 128)  # BGR, gray
THICKNESS = 1
LINETYPE = 1

EXCLUED_STEPS = [
    'OpenCVInit', 'OpenCVDecode', 'DecordInit', 'DecordDecode', 'PyAVInit',
    'PyAVDecode', 'RawFrameDecode'
]


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 webcam demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('label', help='label file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--camera-id', type=int, default=0, help='camera device id')
    parser.add_argument(
        '--threshold',
        type=float,
        default=0.01,
        help='recognition score threshold')
    parser.add_argument(
        '--average-size',
        type=int,
        default=1,
        help='number of latest clips to be averaged for prediction')
    parser.add_argument(
        '--drawing-fps',
        type=int,
        default=20,
        help='Set upper bound FPS value of the output drawing')
    parser.add_argument(
        '--inference-fps',
        type=int,
        default=4,
        help='Set upper bound FPS value of model inference')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    args = parser.parse_args()
    assert args.drawing_fps >= 0 and args.inference_fps >= 0, \
        'upper bound FPS value of drawing and inference should be set as ' \
        'a positive number, or zero for no limit'
    return args


def show_results():
    """Read frames from the camera, overlay the latest recognition results
    and display the annotated stream in an OpenCV window."""
    print('Press "Esc", "q" or "Q" to exit')

    text_info = {}
    cur_time = time.time()
    while True:
        msg = 'Waiting for action ...'
        _, frame = camera.read()
        # VideoCapture returns BGR; reverse the channels to RGB for the model
        frame_queue.append(np.array(frame[:, :, ::-1]))

        if len(result_queue) != 0:
            text_info = {}
            results = result_queue.popleft()
            for i, result in enumerate(results):
                selected_label, score = result
                if score < threshold:
                    break
                location = (0, 40 + i * 20)
                text = selected_label + ': ' + str(round(score, 2))
                text_info[location] = text
                cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                            FONTCOLOR, THICKNESS, LINETYPE)

        elif len(text_info) != 0:
            for location, text in text_info.items():
                cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                            FONTCOLOR, THICKNESS, LINETYPE)

        else:
            cv2.putText(frame, msg, (0, 40), FONTFACE, FONTSCALE, MSGCOLOR,
                        THICKNESS, LINETYPE)

        cv2.imshow('camera', frame)
        ch = cv2.waitKey(1)

        if ch == 27 or ch == ord('q') or ch == ord('Q'):
            break

        if drawing_fps > 0:
            # add a limiter for actual drawing fps <= drawing_fps
            sleep_time = 1 / drawing_fps - (time.time() - cur_time)
            if sleep_time > 0:
                time.sleep(sleep_time)
            cur_time = time.time()


def inference():
    """Run the recognizer on sliding windows of frames and push the averaged
    top predictions to result_queue."""
    score_cache = deque()
    scores_sum = 0
    cur_time = time.time()
    while True:
        cur_windows = []

        while len(cur_windows) == 0:
            if len(frame_queue) == sample_length:
                cur_windows = list(np.array(frame_queue))
                if data['img_shape'] is None:
                    data['img_shape'] = frame_queue.popleft().shape[:2]

        cur_data = data.copy()
        cur_data['imgs'] = cur_windows
        cur_data = test_pipeline(cur_data)
        cur_data = collate([cur_data], samples_per_gpu=1)
        if next(model.parameters()).is_cuda:
            cur_data = scatter(cur_data, [device])[0]

        with torch.no_grad():
            scores = model(return_loss=False, **cur_data)[0]

        score_cache.append(scores)
        scores_sum += scores

        if len(score_cache) == average_size:
            scores_avg = scores_sum / average_size
            num_selected_labels = min(len(label), 5)

            scores_tuples = tuple(zip(label, scores_avg))
            scores_sorted = sorted(
                scores_tuples, key=itemgetter(1), reverse=True)
            results = scores_sorted[:num_selected_labels]

            result_queue.append(results)
            scores_sum -= score_cache.popleft()

        if inference_fps > 0:
            # add a limiter for actual inference fps <= inference_fps
            sleep_time = 1 / inference_fps - (time.time() - cur_time)
            if sleep_time > 0:
                time.sleep(sleep_time)
            cur_time = time.time()

    camera.release()
    cv2.destroyAllWindows()


def main():
    global frame_queue, camera, frame, results, threshold, sample_length, \
        data, test_pipeline, model, device, average_size, label, \
        result_queue, drawing_fps, inference_fps

    args = parse_args()
    average_size = args.average_size
    threshold = args.threshold
    drawing_fps = args.drawing_fps
    inference_fps = args.inference_fps

    device = torch.device(args.device)

    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)

    model = init_recognizer(cfg, args.checkpoint, device=device)
    camera = cv2.VideoCapture(args.camera_id)
    data = dict(img_shape=None, modality='RGB', label=-1)

    with open(args.label, 'r') as f:
        label = [line.strip() for line in f]

    # prepare test pipeline from non-camera pipeline
    cfg = model.cfg
    sample_length = 0
    pipeline = cfg.data.test.pipeline
    pipeline_ = pipeline.copy()
    for step in pipeline:
        if 'SampleFrames' in step['type']:
            sample_length = step['clip_len'] * step['num_clips']
            data['num_clips'] = step['num_clips']
            data['clip_len'] = step['clip_len']
            pipeline_.remove(step)
        if step['type'] in EXCLUED_STEPS:
            # remove step to decode frames
            pipeline_.remove(step)
    test_pipeline = Compose(pipeline_)

    assert sample_length > 0

    try:
        frame_queue = deque(maxlen=sample_length)
        result_queue = deque(maxlen=1)
        pw = Thread(target=show_results, args=(), daemon=True)
        pr = Thread(target=inference, args=(), daemon=True)
        pw.start()
        pr.start()
        pw.join()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
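
For reference, a typical invocation of this demo could look like the following. The config, checkpoint and label-file paths are placeholders to be replaced with files from the MMAction2 model zoo (for example a Kinetics-400 recognizer config, its checkpoint and the matching label map); the optional flags correspond to the arguments defined in parse_args above:

    python demo/webcam_demo.py CONFIG_FILE CHECKPOINT_FILE LABEL_FILE \
        --average-size 5 --threshold 0.2 --device cpu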