Diff of /utilities/runUtils.py [000000] .. [a18f15]

Switch to unified view

a b/utilities/runUtils.py
1
import sys, os
2
import random
3
import numpy as np
4
import torch
5
6
7
def START_SEED(seed=73):
8
    np.random.seed(seed)
9
    random.seed(seed)
10
    torch.manual_seed(seed)
11
    torch.cuda.manual_seed(seed)
12
    torch.cuda.manual_seed_all(seed)
13
    torch.backends.cudnn.benchmark = False
14
    torch.backends.cudnn.deterministic = True
15
16
17
##========================== WEIGTHS ===========================================
18
19
def load_pretrained(model, weight_path, flexible = False):
20
    if not weight_path:
21
        print("No weight file to be loaded returning Model with Random weights")
22
        return model
23
24
    def _purge(key): # hardcoded logic
25
        return key.replace("backbone.", "")
26
27
    model_dict = model.state_dict()
28
    weight_dict = torch.load(weight_path)
29
30
    if 'model' in weight_dict.keys():
31
        pretrain_dict = weight_dict['model']
32
    else:
33
        pretrain_dict = weight_dict
34
35
    pretrain_dict = { _purge(k) : v for k, v in pretrain_dict.items()}
36
37
    if flexible:
38
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
39
    if not len(pretrain_dict):
40
        raise Exception(f"No weight names match to be loaded; though file exits ! {weight_path}, Dict: {weight_dict.keys()}")
41
42
    print(f"Pretrained layers:{pretrain_dict.keys()}")
43
44
    model_dict.update(pretrain_dict)
45
    model.load_state_dict(model_dict)
46
47
    return model
48
49
50
51
def count_train_param(model):
52
    train_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
53
    print('The model has {} trainable parameters'.format(train_params_count))
54
    return train_params_count
55
56
57
def freeze_params(model, exclusion_list = []):
58
    ## TODO: Exclusion lists
59
    for param in model.parameters():
60
        if param not in exclusion_list:
61
            param.requires_grad = False
62
    return model
63
64
##==============================================================================
65
66
67
class ObjDict(dict):
68
    """
69
    reference: https://stackoverflow.com/a/32107024
70
    """
71
    def __init__(self, *args, **kwargs):
72
        super(ObjDict, self).__init__(*args, **kwargs)
73
        for arg in args:
74
            if isinstance(arg, dict):
75
                for k, v in arg.items():
76
                    self[k] = v
77
        if kwargs:
78
            for k, v in kwargs.items():
79
                self[k] = v
80
81
    def __getattr__(self, attr):
82
        return self.get(attr)
83
84
    def __setattr__(self, key, value):
85
        self.__setitem__(key, value)
86
87
    def __setitem__(self, key, value):
88
        super(ObjDict, self).__setitem__(key, value)
89
        self.__dict__.update({key: value})
90
91
    def __delattr__(self, item):
92
        self.__delitem__(item)
93
94
    def __delitem__(self, key):
95
        super(ObjDict, self).__delitem__(key)
96
        del self.__dict__[key]