Diff of /utilities/runUtils.py [000000] .. [a18f15]

Switch to side-by-side view

--- a
+++ b/utilities/runUtils.py
@@ -0,0 +1,96 @@
+import sys, os
+import random
+import numpy as np
+import torch
+
+
+def START_SEED(seed=73):
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+
+##========================== WEIGTHS ===========================================
+
+def load_pretrained(model, weight_path, flexible = False):
+    if not weight_path:
+        print("No weight file to be loaded returning Model with Random weights")
+        return model
+
+    def _purge(key): # hardcoded logic
+        return key.replace("backbone.", "")
+
+    model_dict = model.state_dict()
+    weight_dict = torch.load(weight_path)
+
+    if 'model' in weight_dict.keys():
+        pretrain_dict = weight_dict['model']
+    else:
+        pretrain_dict = weight_dict
+
+    pretrain_dict = { _purge(k) : v for k, v in pretrain_dict.items()}
+
+    if flexible:
+        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
+    if not len(pretrain_dict):
+        raise Exception(f"No weight names match to be loaded; though file exits ! {weight_path}, Dict: {weight_dict.keys()}")
+
+    print(f"Pretrained layers:{pretrain_dict.keys()}")
+
+    model_dict.update(pretrain_dict)
+    model.load_state_dict(model_dict)
+
+    return model
+
+
+
+def count_train_param(model):
+    train_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print('The model has {} trainable parameters'.format(train_params_count))
+    return train_params_count
+
+
+def freeze_params(model, exclusion_list = []):
+    ## TODO: Exclusion lists
+    for param in model.parameters():
+        if param not in exclusion_list:
+            param.requires_grad = False
+    return model
+
+##==============================================================================
+
+
+class ObjDict(dict):
+    """
+    reference: https://stackoverflow.com/a/32107024
+    """
+    def __init__(self, *args, **kwargs):
+        super(ObjDict, self).__init__(*args, **kwargs)
+        for arg in args:
+            if isinstance(arg, dict):
+                for k, v in arg.items():
+                    self[k] = v
+        if kwargs:
+            for k, v in kwargs.items():
+                self[k] = v
+
+    def __getattr__(self, attr):
+        return self.get(attr)
+
+    def __setattr__(self, key, value):
+        self.__setitem__(key, value)
+
+    def __setitem__(self, key, value):
+        super(ObjDict, self).__setitem__(key, value)
+        self.__dict__.update({key: value})
+
+    def __delattr__(self, item):
+        self.__delitem__(item)
+
+    def __delitem__(self, key):
+        super(ObjDict, self).__delitem__(key)
+        del self.__dict__[key]
\ No newline at end of file