Diff of /opengait/utils/common.py [000000] .. [fd9ef4]

--- a
+++ b/opengait/utils/common.py
@@ -0,0 +1,205 @@
+import copy
+import os
+import inspect
+import logging
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.autograd as autograd
+import yaml
+import random
+from torch.nn.parallel import DistributedDataParallel as DDP
+from collections import OrderedDict, namedtuple
+
+
+class NoOp:
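+    """Object whose every attribute access returns a do-nothing callable."""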
+    def __getattr__(self, *args):
+        def no_op(*args, **kwargs): pass
+        return no_op
+
+
+class Odict(OrderedDict):
+    def append(self, odict):
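+        """Merge `odict` into this dict, accumulating values under existing keys into lists."""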
+        dst_keys = self.keys()
+        for k, v in odict.items():
+            if not is_list(v):
+                v = [v]
+            if k in dst_keys:
+                if is_list(self[k]):
+                    self[k] += v
+                else:
+                    self[k] = [self[k]] + v
+            else:
+                self[k] = v
+
+
+def Ntuple(description, keys, values):
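+    """Build a namedtuple named `description` with fields `keys`, filled with `values`.
+    A single key/value pair may be passed as bare scalars."""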
+    if not is_list_or_tuple(keys):
+        keys = [keys]
+        values = [values]
+    Tuple = namedtuple(description, keys)
+    return Tuple._make(values)
+
+
+def get_valid_args(obj, input_args, free_keys=[]):
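+    """Return the subset of `input_args` accepted by `obj`'s signature.
+    Keys listed in `free_keys` are silently ignored; any other unknown key is logged."""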
+    if inspect.isfunction(obj):
+        expected_keys = inspect.getfullargspec(obj)[0]
+    elif inspect.isclass(obj):
+        expected_keys = inspect.getfullargspec(obj.__init__)[0]
+    else:
+        raise ValueError('Only functions and classes are supported!')
+    unexpect_keys = list()
+    expected_args = {}
+    for k, v in input_args.items():
+        if k in expected_keys:
+            expected_args[k] = v
+        elif k in free_keys:
+            pass
+        else:
+            unexpect_keys.append(k)
+    if unexpect_keys:
+        logging.info("Found unexpected args (%s) in the configuration of - %s -" %
+                     (', '.join(unexpect_keys), obj.__name__))
+    return expected_args
+
+
+def get_attr_from(sources, name):
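+    """Return attribute `name` from the first object in `sources` that defines it."""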
+    try:
+        return getattr(sources[0], name)
+    except AttributeError:
+        return get_attr_from(sources[1:], name) if len(sources) > 1 else getattr(sources[0], name)
+
+
+def is_list_or_tuple(x):
+    return isinstance(x, (list, tuple))
+
+
+def is_bool(x):
+    return isinstance(x, bool)
+
+
+def is_str(x):
+    return isinstance(x, str)
+
+
+def is_list(x):
+    return isinstance(x, (list, nn.ModuleList))
+
+
+def is_dict(x):
+    return isinstance(x, (dict, OrderedDict, Odict))
+
+
+def is_tensor(x):
+    return isinstance(x, torch.Tensor)
+
+
+def is_array(x):
+    return isinstance(x, np.ndarray)
+
+
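+# Tensor/ndarray/list conversion helpers. `autograd.Variable` has been a
+# deprecated thin wrapper since PyTorch 0.4, so ts2var/np2var essentially just
+# move the data onto the current CUDA device (with **kwargs such as
+# requires_grad applied).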
+def ts2np(x):
+    return x.detach().cpu().numpy()
+
+
+def ts2var(x, **kwargs):
+    return autograd.Variable(x, **kwargs).cuda()
+
+
+def np2var(x, **kwargs):
+    return ts2var(torch.from_numpy(x), **kwargs)
+
+
+def list2var(x, **kwargs):
+    return np2var(np.array(x), **kwargs)
+
+
+def mkdir(path):
+    os.makedirs(path, exist_ok=True)
+
+
+def MergeCfgsDict(src, dst):
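+    """Recursively merge config dict `src` into `dst` in place; on conflicts, `src` wins."""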
+    for k, v in src.items():
+        if (k not in dst) or type(v) is not dict:
+            dst[k] = v
+        else:
+            if is_dict(src[k]) and is_dict(dst[k]):
+                MergeCfgsDict(src[k], dst[k])
+            else:
+                dst[k] = v
+
+
+def clones(module, N):
+    "Produce N identical layers."
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+def config_loader(path):
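+    """Load the YAML config at `path` and merge it over ./configs/default.yaml (user values override defaults)."""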
+    with open(path, 'r') as stream:
+        src_cfgs = yaml.safe_load(stream)
+    with open("./configs/default.yaml", 'r') as stream:
+        dst_cfgs = yaml.safe_load(stream)
+    MergeCfgsDict(src_cfgs, dst_cfgs)
+    return dst_cfgs
+
+
+def init_seeds(seed=0, cuda_deterministic=True):
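+    """Seed Python, NumPy and (all CUDA) torch RNGs; optionally force deterministic cuDNN."""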
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
+    if cuda_deterministic:  # slower, more reproducible
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+    else:  # faster, less reproducible
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
+def handler(signum, frame):
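+    """Signal handler (Ctrl+C / Ctrl+Z): kill every running main.py process to tear down the whole DDP group."""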
+    logging.info('Ctrl+C/Z pressed')
+    os.system(
+        "kill $(ps aux | grep main.py | grep -v grep | awk '{print $2}') ")
+    logging.info('process group flushed!')
+
+
+def ddp_all_gather(features, dim=0, requires_grad=True):
+    '''
+        Gather `features` from every DDP rank and concatenate them along `dim`.
+        inputs: [n, ...] per rank -> output: [world_size * n, ...] for dim=0.
+        With requires_grad=True, the local rank's slice keeps its autograd graph
+        (tensors returned by all_gather carry no gradients).
+    '''
+
+    world_size = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
+    feature_list = [torch.ones_like(features) for _ in range(world_size)]
+    torch.distributed.all_gather(feature_list, features.contiguous())
+
+    if requires_grad:
+        feature_list[rank] = features
+    feature = torch.cat(feature_list, dim=dim)
+    return feature
+
+
+# https://github.com/pytorch/pytorch/issues/16885
+class DDPPassthrough(DDP):
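+    """DDP wrapper that forwards attribute lookups it does not handle to the wrapped module."""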
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
+
+
+def get_ddp_module(module, find_unused_parameters=False, **kwargs):
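+    """Wrap `module` in DDPPassthrough on the current CUDA device; parameter-less modules are returned unchanged."""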
+    if len(list(module.parameters())) == 0:
+        # A parameter-less module (e.g., a loss with no learnable weights) cannot be wrapped by DDP.
+        return module
+    device = torch.cuda.current_device()
+    module = DDPPassthrough(module, device_ids=[device], output_device=device,
+                            find_unused_parameters=find_unused_parameters, **kwargs)
+    return module
+
+
+def params_count(net):
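+    """Return a formatted string with `net`'s total parameter count, in millions."""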
+    n_parameters = sum(p.numel() for p in net.parameters())
+    return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)