--- a +++ b/tools/deployment/mmaction_handler.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import base64 +import os +import os.path as osp +import warnings + +import decord +import numpy as np +import torch + +from mmaction.apis import inference_recognizer, init_recognizer # noqa: F401 + +try: + from ts.torch_handler.base_handler import BaseHandler +except ImportError: + raise ImportError('`ts` is required. Try: pip install ts.') + + +class MMActionHandler(BaseHandler): + + def initialize(self, context): + properties = context.system_properties + self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(self.map_location + ':' + + str(properties.get('gpu_id')) if torch.cuda. + is_available() else self.map_location) + self.manifest = context.manifest + + model_dir = properties.get('model_dir') + serialized_file = self.manifest['model']['serializedFile'] + checkpoint = os.path.join(model_dir, serialized_file) + self.config_file = os.path.join(model_dir, 'config.py') + + mapping_file_path = osp.join(model_dir, 'label_map.txt') + if not os.path.isfile(mapping_file_path): + warnings.warn('Missing the label_map.txt file. ' + 'Inference output will not include class name.') + self.mapping = None + else: + lines = open(mapping_file_path).readlines() + self.mapping = [x.strip() for x in lines] + + self.model = init_recognizer(self.config_file, checkpoint, self.device) + self.initialized = True + + def preprocess(self, data): + videos = [] + + for row in data: + video = row.get('data') or row.get('body') + if isinstance(video, str): + video = base64.b64decode(video) + # First save the bytes as a tmp file + with open('/tmp/tmp.mp4', 'wb') as fout: + fout.write(video) + + video = decord.VideoReader('/tmp/tmp.mp4') + frames = [x.asnumpy() for x in video] + videos.append(np.stack(frames)) + + return videos + + def inference(self, data, *args, **kwargs): + results = [inference_recognizer(self.model, item) for item in data] + return results + + def postprocess(self, data): + # Format output following the example ObjectDetectionHandler format + output = [] + for video_idx, video_result in enumerate(data): + output.append([]) + assert isinstance(video_result, list) + + output[video_idx] = { + self.mapping[x[0]] if self.mapping else x[0]: float(x[1]) + for x in video_result + } + + return output