utilities/runUtils.py
|
import sys, os
import random
import numpy as np
import torch


def START_SEED(seed=73):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
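    ## note: benchmark=False + deterministic=True trade speed for
    ## reproducibility; newer PyTorch versions may additionally need
    ## torch.use_deterministic_algorithms(True) for full determinism.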
|
|
##========================== WEIGHTS ===========================================

def load_pretrained(model, weight_path, flexible=False):
    if not weight_path:
        print("No weight file given; returning model with random weights")
        return model

    def _purge(key):  # hardcoded logic: strip the "backbone." prefix from checkpoint keys
        return key.replace("backbone.", "")

    model_dict = model.state_dict()
    weight_dict = torch.load(weight_path)

    if 'model' in weight_dict.keys():
        pretrain_dict = weight_dict['model']
    else:
        pretrain_dict = weight_dict

    pretrain_dict = {_purge(k): v for k, v in pretrain_dict.items()}

    if flexible:
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    if not len(pretrain_dict):
        raise Exception(f"No weight names match to be loaded, though the file exists! {weight_path}, Dict: {weight_dict.keys()}")

    print(f"Pretrained layers: {pretrain_dict.keys()}")

    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)

    return model
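## Usage sketch (the checkpoint path below is hypothetical):
##   model = load_pretrained(model, "checkpoints/model_best.pth", flexible=True)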
|
|
def count_train_param(model):
    train_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('The model has {} trainable parameters'.format(train_params_count))
    return train_params_count


def freeze_params(model, exclusion_list=None):
    ## TODO: exclusion by parameter name
    exclusion_list = exclusion_list or []
    for param in model.parameters():
        # membership via identity: a plain `in` on tensors triggers
        # element-wise comparison and raises for non-empty lists
        if not any(param is excluded for excluded in exclusion_list):
            param.requires_grad = False
    return model

##==============================================================================


class ObjDict(dict):
    """
    reference: https://stackoverflow.com/a/32107024
    """
    def __init__(self, *args, **kwargs):
        super(ObjDict, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v
        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(ObjDict, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(ObjDict, self).__delitem__(key)
        del self.__dict__[key]
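

## ---------------------------------------------------------------------------
## Minimal usage sketch (illustrative only; the Linear model below is a
## stand-in, not part of these utilities):
if __name__ == "__main__":
    START_SEED()

    demo_model = torch.nn.Linear(4, 2)
    count_train_param(demo_model)

    freeze_params(demo_model)
    assert all(not p.requires_grad for p in demo_model.parameters())

    cfg = ObjDict({"lr": 1e-3})
    cfg.epochs = 10
    print(cfg.lr, cfg["epochs"])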