--- /dev/null
+++ b/mmaction/utils/module_hooks.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.utils import Registry, build_from_cfg
+
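+# Registry for module hooks that can be registered on model submodules.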
+MODULE_HOOKS = Registry('module_hooks')
+
+
+def register_module_hooks(model, module_hooks_list):
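+    """Register module hooks on submodules of the given model.
+
+    Args:
+        model (nn.Module): The model whose submodules will be hooked.
+        module_hooks_list (list[dict]): List of hook configs. Each config
+            may specify ``hooked_module`` (default ``'backbone'``) and
+            ``hook_pos`` (``'forward_pre'``, ``'forward'`` or
+            ``'backward'``, default ``'forward_pre'``); the remaining keys
+            build the hook from the ``MODULE_HOOKS`` registry.
+
+    Returns:
+        list[torch.utils.hooks.RemovableHandle]: Handles that can be used
+            to remove the registered hooks via ``handle.remove()``.
+    """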
+    handles = []
+    for module_hook_cfg in module_hooks_list:
+        hooked_module_name = module_hook_cfg.pop('hooked_module', 'backbone')
+        if not hasattr(model, hooked_module_name):
+            raise ValueError(
+                f'{model.__class__} has no submodule named '
+                f'{hooked_module_name}!')
+        hooked_module = getattr(model, hooked_module_name)
+        hook_pos = module_hook_cfg.pop('hook_pos', 'forward_pre')
+
+        hook = build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func()
+        if hook_pos == 'forward_pre':
+            handle = hooked_module.register_forward_pre_hook(hook)
+        elif hook_pos == 'forward':
+            handle = hooked_module.register_forward_hook(hook)
+        elif hook_pos == 'backward':
+            handle = hooked_module.register_backward_hook(hook)
+        else:
+            raise ValueError(
+                f'hook_pos must be `forward_pre`, `forward` or `backward`, '
+                f'but got {hook_pos}')
+        handles.append(handle)
+    return handles
+
+
+@MODULE_HOOKS.register_module()
+class GPUNormalize:
+    """Normalize images with the given mean and std value on GPUs.
+
+    Calling the member function ``hook_func`` returns the forward pre-hook
+    function for module registration.
+
+    GPU normalization is preferable to CPU normalization when the model runs
+    on GPUs with strong compute capability, such as Tesla V100.
+
+    Args:
+        input_format (str): Input data format, candidates are ``NCTHW``,
+            ``NCHW``, ``NCHW_Flow`` and ``NPTCHW``.
+        mean (Sequence[float]): Mean values of different channels.
+        std (Sequence[float]): Std values of different channels.
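+
+    Example:
+        >>> # A minimal usage sketch. ``recognizer`` is assumed to be a
+        >>> # model exposing a ``backbone`` submodule; the mean/std below
+        >>> # are the common ImageNet values, not values required here.
+        >>> module_hooks = [
+        ...     dict(
+        ...         type='GPUNormalize',
+        ...         hooked_module='backbone',
+        ...         hook_pos='forward_pre',
+        ...         input_format='NCHW',
+        ...         mean=[123.675, 116.28, 103.53],
+        ...         std=[58.395, 57.12, 57.375])
+        ... ]
+        >>> handles = register_module_hooks(recognizer, module_hooks)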
+    """
+
+    def __init__(self, input_format, mean, std):
+        if input_format not in ['NCTHW', 'NCHW', 'NCHW_Flow', 'NPTCHW']:
+            raise ValueError(f'The input format {input_format} is invalid.')
+        self.input_format = input_format
+
+        # Reshape mean/std so that they broadcast over the channel axis of
+        # the corresponding input layout.
+        _mean = torch.tensor(mean)
+        _std = torch.tensor(std)
+        if input_format == 'NCTHW':
+            self._mean = _mean[None, :, None, None, None]
+            self._std = _std[None, :, None, None, None]
+        elif input_format in ('NCHW', 'NCHW_Flow'):
+            self._mean = _mean[None, :, None, None]
+            self._std = _std[None, :, None, None]
+        elif input_format == 'NPTCHW':
+            self._mean = _mean[None, None, None, :, None, None]
+            self._std = _std[None, None, None, :, None, None]
+
+    def hook_func(self):
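+        """Return the hook function to be registered as a forward pre-hook
+        on the hooked module."""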
+
+        def normalize_hook(module, inputs):
+            x = inputs[0]
+            assert x.dtype == torch.uint8, (
+                f'The previous augmentation should use the uint8 data type '
+                f'to speed up computation, but got {x.dtype}')
+
+            mean = self._mean.to(x.device)
+            std = self._std.to(x.device)
+
+            # Cast to float once, then normalize in-place on the copy.
+            with torch.no_grad():
+                x = x.float().sub_(mean).div_(std)
+
+            return (x, *inputs[1:])
+
+        return normalize_hook