--- /dev/null
+++ b/mmaction/utils/module_hooks.py
@@ -0,0 +1,91 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.utils import Registry, build_from_cfg
+
+MODULE_HOOKS = Registry('module_hooks')
+
+
+def register_module_hooks(Module, module_hooks_list):
+    handles = []
+    for module_hook_cfg in module_hooks_list:
+        hooked_module_name = module_hook_cfg.pop('hooked_module', 'backbone')
+        if not hasattr(Module, hooked_module_name):
+            raise ValueError(
+                f'{Module.__class__} has no {hooked_module_name}!')
+        hooked_module = getattr(Module, hooked_module_name)
+        hook_pos = module_hook_cfg.pop('hook_pos', 'forward_pre')
+
+        if hook_pos == 'forward_pre':
+            handle = hooked_module.register_forward_pre_hook(
+                build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
+        elif hook_pos == 'forward':
+            handle = hooked_module.register_forward_hook(
+                build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
+        elif hook_pos == 'backward':
+            handle = hooked_module.register_backward_hook(
+                build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
+        else:
+            raise ValueError(
+                f'hook_pos must be `forward_pre`, `forward` or `backward`, '
+                f'but got {hook_pos}')
+        handles.append(handle)
+    return handles
+
+
+@MODULE_HOOKS.register_module()
+class GPUNormalize:
+    """Normalize images with the given mean and std values on GPUs.
+
+    Calling the member function ``hook_func`` returns the forward pre-hook
+    function to be registered on a module.
+
+    GPU normalization is recommended over CPU normalization when the model
+    runs on GPUs with strong compute capability, such as a Tesla V100.
+
+    Args:
+        input_format (str): Layout of the input data. Should be one of
+            'NCTHW', 'NCHW', 'NCHW_Flow' or 'NPTCHW'.
+        mean (Sequence[float]): Mean values of different channels.
+        std (Sequence[float]): Std values of different channels.
+    """
+
+    def __init__(self, input_format, mean, std):
+        if input_format not in ['NCTHW', 'NCHW', 'NCHW_Flow', 'NPTCHW']:
+            raise ValueError(f'The input format {input_format} is invalid.')
+        self.input_format = input_format
+        _mean = torch.tensor(mean)
+        _std = torch.tensor(std)
+        # Insert singleton dims so mean/std broadcast over the input layout.
+        if input_format == 'NCTHW':
+            self._mean = _mean[None, :, None, None, None]
+            self._std = _std[None, :, None, None, None]
+        elif input_format == 'NCHW':
+            self._mean = _mean[None, :, None, None]
+            self._std = _std[None, :, None, None]
+        elif input_format == 'NCHW_Flow':
+            self._mean = _mean[None, :, None, None]
+            self._std = _std[None, :, None, None]
+        elif input_format == 'NPTCHW':
+            self._mean = _mean[None, None, None, :, None, None]
+            self._std = _std[None, None, None, :, None, None]
+        else:
+            raise ValueError(f'The input format {input_format} is invalid.')
+
+    def hook_func(self):
+        """Return the forward pre-hook that normalizes input batches."""
+
+        def normalize_hook(Module, input):
+            x = input[0]
+            assert x.dtype == torch.uint8, (
+                f'The previous transforms should keep the uint8 data type '
+                f'to speed up computation, but got {x.dtype}')
+
+            mean = self._mean.to(x.device)
+            std = self._std.to(x.device)
+
+            with torch.no_grad():
+                x = x.float().sub_(mean).div_(std)
+
+            return (x, *input[1:])
+
+        return normalize_hook
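
For reference, a minimal usage sketch (not part of the patch): it registers a GPUNormalize forward pre-hook on a toy model. The Recognizer class and the ImageNet mean/std values are illustrative assumptions; only register_module_hooks, GPUNormalize, and the config keys (type, hooked_module, hook_pos, input_format, mean, std) come from the diff above.

    import torch
    import torch.nn as nn

    from mmaction.utils.module_hooks import register_module_hooks


    class Recognizer(nn.Module):
        """Toy model exposing a `backbone` attribute, which the default
        hooked_module='backbone' expects."""

        def __init__(self):
            super().__init__()
            self.backbone = nn.Conv2d(3, 16, 3)

        def forward(self, x):
            return self.backbone(x)


    module_hooks = [
        dict(
            type='GPUNormalize',
            hooked_module='backbone',
            hook_pos='forward_pre',
            input_format='NCHW',
            # ImageNet statistics, chosen purely for illustration.
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375])
    ]

    model = Recognizer()
    handles = register_module_hooks(model, module_hooks)

    # The pre-hook casts the uint8 batch to float and normalizes it on the
    # input's device before the backbone sees it.
    x = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
    out = model(x)

    for handle in handles:
        handle.remove()  # detach the hooks when they are no longer needed

Note that register_module_hooks pops hooked_module and hook_pos from each config dict, so a config list should not be reused across calls.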
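
The registry also makes the hook set extensible. Below is a hypothetical sketch of a custom hook for the `forward` position; the class name OutputShapeLogger is an assumption, while the MODULE_HOOKS registry and the hook_func contract come from the diff.

    from mmaction.utils.module_hooks import MODULE_HOOKS


    @MODULE_HOOKS.register_module()
    class OutputShapeLogger:
        """Print the output shape of the hooked module (illustrative only)."""

        def hook_func(self):

            def shape_hook(module, input, output):
                # Forward hooks receive (module, input, output); returning
                # None leaves the output unchanged.
                print(f'{module.__class__.__name__} output: '
                      f'{tuple(output.shape)}')

            return shape_hook

It would then be enabled by adding dict(type='OutputShapeLogger', hooked_module='backbone', hook_pos='forward') to the module_hooks list.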