[6d389a]: / mmaction / core / hooks / output.py

Download this file

69 lines (51 with data), 2.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import warnings
import torch
class OutputHook:
    """Capture the forward outputs of selected layers of a module.

    A forward hook is registered on every layer named in ``outputs``. Each
    time such a layer runs, its latest output is stored in
    ``self.layer_outputs`` under the layer name. The hooks stay attached
    until :meth:`remove` is called; using the instance as a context manager
    removes them on exit.

    Args:
        module (nn.Module): The whole module to get layers.
        outputs (tuple[str] | list[str]): Layer name to output. Any value
            that is not a list or tuple (including the default ``None``)
            registers no hooks. Default: None.
        as_tensor (bool): Determine to return a tensor or a numpy array.
            Default: False.

    Raises:
        AttributeError: If a name in ``outputs`` cannot be resolved to a
            layer that supports ``register_forward_hook``.
    """

    def __init__(self, module, outputs=None, as_tensor=False):
        self.outputs = outputs
        self.as_tensor = as_tensor
        # Latest captured output per layer name.
        self.layer_outputs = {}
        # Handles returned by ``register_forward_hook``; needed for removal.
        self.handles = []
        self.register(module)

    def register(self, module):
        """Attach a forward hook to every layer listed in ``self.outputs``.

        Args:
            module (nn.Module): Module whose layers are looked up by name.

        Raises:
            AttributeError: If a layer name cannot be resolved. Hooks that
                this call had already attached are removed before raising.
        """

        def hook_wrapper(name):
            # Bind ``name`` per layer so each hook writes to its own key.
            def hook(model, inputs, output):
                if not isinstance(output, torch.Tensor):
                    warnings.warn(f'Directly return the output from {name}, '
                                  f'since it is not a tensor')
                    self.layer_outputs[name] = output
                elif self.as_tensor:
                    # Stored as is: not detached and not moved to the CPU.
                    self.layer_outputs[name] = output
                else:
                    self.layer_outputs[name] = output.detach().cpu().numpy()

            return hook

        if not isinstance(self.outputs, (list, tuple)):
            return

        # Handles from this index onwards belong to the current call.
        first_new = len(self.handles)
        for name in self.outputs:
            try:
                layer = rgetattr(module, name)
                h = layer.register_forward_hook(hook_wrapper(name))
            except AttributeError as err:
                # Roll back the hooks attached so far; otherwise they would
                # stay on the module with nobody able to remove them when
                # this is called from the constructor.
                for stale in self.handles[first_new:]:
                    stale.remove()
                del self.handles[first_new:]
                raise AttributeError(f'Module {name} not found') from err
            self.handles.append(h)

    def remove(self):
        """Detach every registered hook; calling it again is a no-op."""
        for h in self.handles:
            h.remove()
        # Forget the handles so they are neither kept alive nor re-removed.
        self.handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always detach the hooks, whether or not the body raised.
        self.remove()
# using wonder's beautiful simplification:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path starting from ``obj``.

    Args:
        obj: Object the lookup starts from.
        attr (str): Attribute names joined by ``.``, e.g. ``'a.b.c'``.
        *args: Optional default value, forwarded to every ``getattr`` call
            made along the path.

    Returns:
        The object reached after following every name in ``attr``.
    """
    target = obj
    for part in attr.split('.'):
        # The optional default applies at every level, not only the last.
        target = getattr(target, part, *args)
    return target