# utilities/runUtils.py (revision a18f15)
import sys, os
import random
import numpy as np
import torch
def START_SEED(seed=73):
    """Seed every random number generator and make cuDNN deterministic.

    Applies ``seed`` to NumPy's global generator, Python's ``random`` module,
    torch's CPU generator, the current CUDA device and all CUDA devices, in
    that order. Then turns off cuDNN autotuning (``benchmark``) and requests
    deterministic cuDNN kernels.

    Args:
        seed: integer seed shared by all generators (default 73).
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
##========================== WEIGHTS ===========================================
def load_pretrained(model, weight_path, flexible=False, map_location="cpu"):
    """Load checkpoint weights into ``model`` and return the same model.

    If the checkpoint has a ``'model'`` entry, that entry is used as the state
    dict; otherwise the loaded object itself is. Every occurrence of
    ``"backbone."`` is removed from the key names before matching.

    Args:
        model: ``torch.nn.Module`` that receives the weights in place.
        weight_path: checkpoint file path. A falsy value (``None``/``""``)
            skips loading and returns ``model`` with its current weights.
        flexible: if True, checkpoint keys the model does not have are
            dropped before loading. Shape mismatches are NOT filtered.
        map_location: forwarded to ``torch.load``. The default ``"cpu"``
            lets GPU-saved checkpoints load on CPU-only hosts; the values are
            copied into the model's own parameters by ``load_state_dict``.

    Returns:
        ``model``, updated in place.

    Raises:
        RuntimeError: if no weights remain to be loaded. ``load_state_dict``
            may also raise RuntimeError, e.g. for unexpected keys when
            ``flexible`` is False or for tensor size mismatches.
    """
    if not weight_path:
        print("No weight file to be loaded returning Model with Random weights")
        return model

    def _purge(key):  # hardcoded logic: strip every "backbone." occurrence
        return key.replace("backbone.", "")

    model_dict = model.state_dict()
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    weight_dict = torch.load(weight_path, map_location=map_location)
    pretrain_dict = weight_dict['model'] if 'model' in weight_dict else weight_dict
    pretrain_dict = {_purge(k): v for k, v in pretrain_dict.items()}

    if flexible:
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    if not pretrain_dict:
        raise RuntimeError(
            f"No weight names match to be loaded; though file exists ! {weight_path}, "
            f"Dict: {weight_dict.keys()}"
        )

    print(f"Pretrained layers:{pretrain_dict.keys()}")
    # Start from the model's own state so keys absent from the checkpoint keep
    # their current values, then overwrite with the checkpoint tensors.
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
    return model
def count_train_param(model):
    """Print and return the number of trainable scalars in ``model``.

    Sums ``numel()`` over every parameter whose ``requires_grad`` flag is set.

    Args:
        model: ``torch.nn.Module`` to inspect.

    Returns:
        Total element count of the trainable parameters (an int).
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    print('The model has {} trainable parameters'.format(total))
    return total
def freeze_params(model, exclusion_list=None):
    """Turn off gradients for every parameter of ``model`` except excluded ones.

    Args:
        model: ``torch.nn.Module`` whose parameters are frozen in place.
        exclusion_list: optional iterable of parameters to leave untouched.
            Each entry is either the parameter tensor itself (matched by
            identity) or its name as reported by ``model.named_parameters()``,
            e.g. ``"fc.weight"``. Excluded parameters keep their current
            ``requires_grad`` value.

    Returns:
        ``model`` (the same object).
    """
    exclusions = list(exclusion_list) if exclusion_list is not None else []
    # Match tensors by identity: ``param in list_of_tensors`` falls back to
    # Tensor.__eq__, which is element-wise and raises for multi-element or
    # shape-mismatched tensors.
    excluded_ids = {id(p) for p in exclusions if not isinstance(p, str)}
    excluded_names = {p for p in exclusions if isinstance(p, str)}
    for name, param in model.named_parameters():
        if id(param) in excluded_ids or name in excluded_names:
            continue
        param.requires_grad = False
    return model
##==============================================================================
class ObjDict(dict):
    """Dict whose items can also be read, set and deleted as attributes.

    ``d.key`` returns ``d.get("key")`` -- ``None`` for a missing key instead
    of raising -- while ``d.key = v`` and ``del d.key`` map to
    ``d["key"] = v`` and ``del d["key"]``. Dunder names are excluded from the
    ``None`` fallback and raise AttributeError as usual.

    Every item is also mirrored into the instance ``__dict__`` (so ``vars(d)``
    shows the contents). ``update``, ``setdefault``, ``pop``, ``popitem`` and
    ``clear`` are overridden to keep that mirror in sync, because the C
    implementations in ``dict`` bypass ``__setitem__``/``__delitem__``.

    reference: https://stackoverflow.com/a/32107024
    """

    def __init__(self, *args, **kwargs):
        super(ObjDict, self).__init__(*args, **kwargs)
        # dict.__init__ does not go through __setitem__, so mirror explicitly.
        # This covers mappings as well as iterables of (key, value) pairs.
        self.__dict__.update(self)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Dunder names must
        # raise; otherwise protocol probes (pickle/copy looking for
        # __setstate__, numpy for __array__, ...) get None and then crash.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(ObjDict, self).__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(ObjDict, self).__delitem__(key)
        # pop instead of del: tolerate a mirror that never received this key.
        self.__dict__.pop(key, None)

    def update(self, *args, **kwargs):
        # Same argument forms as dict.update; route through __setitem__.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        value = super(ObjDict, self).pop(key, *default)
        self.__dict__.pop(key, None)
        return value

    def popitem(self):
        key, value = super(ObjDict, self).popitem()
        self.__dict__.pop(key, None)
        return key, value

    def clear(self):
        super(ObjDict, self).clear()
        self.__dict__.clear()