a b/opengait/utils/common.py
1
import copy
2
import os
3
import inspect
4
import logging
5
import torch
6
import numpy as np
7
import torch.nn as nn
8
import torch.autograd as autograd
9
import yaml
10
import random
11
from torch.nn.parallel import DistributedDataParallel as DDP
12
from collections import OrderedDict, namedtuple
13
14
15
class NoOp:
    """Placeholder object whose every missing attribute is a do-nothing callable."""

    def __getattr__(self, *args):
        # Reached only when normal attribute lookup fails; hand back a
        # function that accepts any arguments and returns None.
        def no_op(*args, **kwargs):
            return None

        return no_op
19
20
21
class Odict(OrderedDict):
    """OrderedDict with an ``append`` that accumulates values per key in lists."""

    def append(self, odict):
        """Merge the items of ``odict`` into this mapping.

        Each incoming value that is not list-like (per ``is_list``) is wrapped
        in a one-element list. For a key already present, the incoming items
        are added to the stored value; a stored non-list value is first turned
        into a one-element list. New keys simply store the incoming list.
        """
        for key, value in odict.items():
            items = value if is_list(value) else [value]
            if key not in self:
                self[key] = items
                continue
            if is_list(self[key]):
                # In-place extension of the stored list.
                self[key] += items
            else:
                self[key] = [self[key]] + items
34
35
36
def Ntuple(description, keys, values):
    """Build a namedtuple instance named ``description``.

    ``keys``/``values`` may be parallel lists/tuples, or a single key and a
    single value, in which case both are wrapped into one-element lists.
    """
    if is_list_or_tuple(keys):
        field_names, field_values = keys, values
    else:
        field_names, field_values = [keys], [values]
    record_type = namedtuple(description, field_names)
    return record_type._make(field_values)
42
43
44
def get_valid_args(obj, input_args, free_keys=()):
    """Keep only the entries of ``input_args`` that ``obj`` accepts by name.

    Args:
        obj: A plain Python function, or a class (its ``__init__`` signature
            is inspected).
        input_args: Mapping of candidate keyword arguments.
        free_keys: Keys that may appear in ``input_args`` without being
            accepted by ``obj``; they are dropped silently.

    Returns:
        A new dict with the accepted entries, in ``input_args`` order.

    Raises:
        ValueError: If ``obj`` is neither a function nor a class.
    """
    if inspect.isfunction(obj):
        spec = inspect.getfullargspec(obj)
    elif inspect.isclass(obj):
        spec = inspect.getfullargspec(obj.__init__)
    else:
        raise ValueError('Just support function and class object!')

    # Keyword-only parameters (declared after ``*``) are valid keyword
    # arguments too, so they must count as expected keys.
    expected_keys = set(spec.args) | set(spec.kwonlyargs)

    expected_args = {}
    unexpect_keys = []
    for key, value in input_args.items():
        if key in expected_keys:
            expected_args[key] = value
        elif key not in free_keys:
            unexpect_keys.append(key)

    if unexpect_keys:
        logging.info("Find Unexpected Args(%s) in the Configuration of - %s -",
                     ', '.join(unexpect_keys), obj.__name__)
    return expected_args
64
65
66
def get_attr_from(sources, name):
    """Return attribute ``name`` from the first object in ``sources`` that has it.

    Args:
        sources: Non-empty sequence of objects searched in order.
        name: Attribute name to look up.

    Returns:
        The attribute value from the first source that provides it.

    Raises:
        AttributeError: If no source has the attribute (the error raised by
            the last source is propagated).
        IndexError: If ``sources`` is empty.
    """
    last = len(sources) - 1
    for idx, source in enumerate(sources):
        try:
            return getattr(source, name)
        except AttributeError:
            # Only a missing attribute triggers the fallback to the next
            # source; any other exception propagates to the caller.
            if idx == last:
                raise
    raise IndexError('get_attr_from() requires at least one source')
71
72
73
def is_list_or_tuple(x):
    """Return True when ``x`` is a ``list`` or ``tuple`` instance (subclasses included)."""
    if isinstance(x, list):
        return True
    return isinstance(x, tuple)
75
76
77
def is_bool(x):
    """Return True when ``x`` is a ``bool`` instance."""
    result = isinstance(x, bool)
    return result
79
80
81
def is_str(x):
    """Return True when ``x`` is a ``str`` instance (subclasses included)."""
    result = isinstance(x, str)
    return result
83
84
85
def is_list(x):
    """Return True when ``x`` is a ``list`` or an ``nn.ModuleList`` instance."""
    list_like_types = (list, nn.ModuleList)
    return isinstance(x, list_like_types)
87
88
89
def is_dict(x):
    """Return True when ``x`` is a ``dict`` instance (subclasses included)."""
    # OrderedDict derives from dict and Odict derives from OrderedDict, so a
    # single check against dict covers all three.
    return isinstance(x, dict)
91
92
93
def is_tensor(x):
    """Return True when ``x`` is a ``torch.Tensor`` instance."""
    result = isinstance(x, torch.Tensor)
    return result
95
96
97
def is_array(x):
    """Return True when ``x`` is a ``numpy.ndarray`` instance."""
    result = isinstance(x, np.ndarray)
    return result
99
100
101
def ts2np(x):
    """Convert tensor ``x`` to a NumPy array by way of its CPU copy."""
    host_tensor = x.cpu()
    # ``.data`` gives access to the values without autograd tracking.
    return host_tensor.data.numpy()
103
104
105
def ts2var(x, **kwargs):
    """Wrap ``x`` in ``autograd.Variable`` (forwarding ``kwargs``) and move it to CUDA."""
    wrapped = autograd.Variable(x, **kwargs)
    return wrapped.cuda()
107
108
109
def np2var(x, **kwargs):
    """Turn NumPy array ``x`` into a tensor and pass it, with ``kwargs``, to ``ts2var``."""
    as_tensor = torch.from_numpy(x)
    return ts2var(as_tensor, **kwargs)
111
112
113
def list2var(x, **kwargs):
    """Turn ``x`` into a NumPy array and pass it, with ``kwargs``, to ``np2var``."""
    as_array = np.array(x)
    return np2var(as_array, **kwargs)
115
116
117
def mkdir(path):
    """Create directory ``path`` (and missing parents) if nothing exists there.

    If ``path`` already exists, whether as a directory or as a file, nothing
    is done and no error is raised.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the window between the existence check and the
        # creation: another process creating the same directory in between
        # no longer makes this call raise FileExistsError.
        os.makedirs(path, exist_ok=True)
120
121
122
def MergeCfgsDict(src, dst):
    """Recursively merge ``src`` into ``dst`` in place; ``src`` wins on conflicts.

    For each key of ``src``: when both ``src[key]`` and ``dst[key]`` are dicts
    (including dict subclasses such as ``OrderedDict``), the two are merged
    recursively. In every other case ``dst[key]`` is set to ``src[key]``.

    Args:
        src: Mapping holding the overriding values.
        dst: Mapping updated in place.

    Returns:
        None. ``dst`` is modified in place; values taken from ``src`` are
        stored by reference, not copied.
    """
    for key, value in src.items():
        if key in dst and isinstance(value, dict) and isinstance(dst[key], dict):
            MergeCfgsDict(value, dst[key])
        else:
            dst[key] = value
131
132
133
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas
136
137
138
def config_loader(path, default_path="./configs/default.yaml"):
    """Load a YAML config and overlay it on the default config.

    Args:
        path: Path of the YAML file whose values take precedence.
        default_path: Path of the YAML file providing the default values.
            The default is relative to the current working directory.

    Returns:
        The default config dict after ``MergeCfgsDict`` merged the values
        from ``path`` into it.
    """
    with open(path, 'r') as stream:
        src_cfgs = yaml.safe_load(stream)
    with open(default_path, 'r') as stream:
        dst_cfgs = yaml.safe_load(stream)
    # yaml.safe_load returns None for an empty document; treat that as an
    # empty config so the merge below does not fail on None.
    if src_cfgs is None:
        src_cfgs = {}
    if dst_cfgs is None:
        dst_cfgs = {}
    MergeCfgsDict(src_cfgs, dst_cfgs)
    return dst_cfgs
145
146
147
def init_seeds(seed=0, cuda_deterministic=True):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and set cuDNN flags.

    Args:
        seed: Value passed to every seeding function.
        cuda_deterministic: Truthy selects ``cudnn.deterministic=True`` and
            ``cudnn.benchmark=False``; falsy selects the opposite pair.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    # deterministic=True / benchmark=False: slower, more reproducible.
    # deterministic=False / benchmark=True: faster, less reproducible.
    deterministic = bool(cuda_deterministic)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
159
160
161
def handler(signum, frame):
    """Signal-handler-shaped callback that kills processes matching ``main.py``.

    ``signum`` and ``frame`` are accepted but not used.
    """
    logging.info('Ctrl+c/z pressed')
    # NOTE(review): the pipeline selects every `ps aux` line containing
    # "main.py" (minus the grep itself), not only this program's processes.
    kill_cmd = "kill $(ps aux | grep main.py | grep -v grep | awk '{print $2}') "
    os.system(kill_cmd)
    logging.info('process group flush!')
166
167
168
def ddp_all_gather(features, dim=0, requires_grad=True):
    """Gather ``features`` from every rank and concatenate them along ``dim``.

    Args:
        features: Tensor to share across ranks, described by the original
            author as shaped ``[n, ...]``.
        dim: Dimension along which the gathered tensors are concatenated.
        requires_grad: When truthy, this rank's slot in the gathered list is
            replaced by the original ``features`` tensor before concatenation
            (presumably so gradients can still reach the local features;
            confirm against callers).

    Returns:
        The concatenation of the per-rank tensors.
    """
    dist = torch.distributed
    num_ranks = dist.get_world_size()
    my_rank = dist.get_rank()

    # Pre-allocated receive buffers, one per rank.
    gathered = [torch.ones_like(features) for _ in range(num_ranks)]
    dist.all_gather(gathered, features.contiguous())

    if requires_grad:
        gathered[my_rank] = features
    return torch.cat(gathered, dim=dim)
182
183
184
# https://github.com/pytorch/pytorch/issues/16885
185
class DDPPassthrough(DDP):
    """DDP wrapper that forwards unknown attribute lookups to the wrapped module."""

    def __getattr__(self, name):
        # Try the regular DDP lookup first; only when that reports a missing
        # attribute fall back to the wrapped ``self.module``.
        try:
            value = super().__getattr__(name)
        except AttributeError:
            value = getattr(self.module, name)
        return value
191
192
193
def get_ddp_module(module, find_unused_parameters=False, **kwargs):
    """Wrap ``module`` in ``DDPPassthrough`` on the current CUDA device.

    A module without any parameters (for example a loss module that has
    none) is returned unchanged.

    Args:
        module: Module to wrap.
        find_unused_parameters: Forwarded to ``DDPPassthrough``.
        **kwargs: Extra keyword arguments forwarded to ``DDPPassthrough``.

    Returns:
        ``module`` itself when it has no parameters, otherwise the wrapper.
    """
    has_parameters = any(True for _ in module.parameters())
    if not has_parameters:
        return module

    current = torch.cuda.current_device()
    return DDPPassthrough(
        module,
        device_ids=[current],
        output_device=current,
        find_unused_parameters=find_unused_parameters,
        **kwargs,
    )
201
202
203
def params_count(net):
    """Return a string reporting the parameter count of ``net`` in millions (5 decimals)."""
    total = 0
    for param in net.parameters():
        total += param.numel()
    millions = total / 1e6
    return 'Parameters Count: {:.5f}M'.format(millions)