--- a +++ b/utilities/runUtils.py @@ -0,0 +1,96 @@ +import sys, os +import random +import numpy as np +import torch + + +def START_SEED(seed=73): + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +##========================== WEIGTHS =========================================== + +def load_pretrained(model, weight_path, flexible = False): + if not weight_path: + print("No weight file to be loaded returning Model with Random weights") + return model + + def _purge(key): # hardcoded logic + return key.replace("backbone.", "") + + model_dict = model.state_dict() + weight_dict = torch.load(weight_path) + + if 'model' in weight_dict.keys(): + pretrain_dict = weight_dict['model'] + else: + pretrain_dict = weight_dict + + pretrain_dict = { _purge(k) : v for k, v in pretrain_dict.items()} + + if flexible: + pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict} + if not len(pretrain_dict): + raise Exception(f"No weight names match to be loaded; though file exits ! {weight_path}, Dict: {weight_dict.keys()}") + + print(f"Pretrained layers:{pretrain_dict.keys()}") + + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + return model + + + +def count_train_param(model): + train_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('The model has {} trainable parameters'.format(train_params_count)) + return train_params_count + + +def freeze_params(model, exclusion_list = []): + ## TODO: Exclusion lists + for param in model.parameters(): + if param not in exclusion_list: + param.requires_grad = False + return model + +##============================================================================== + + +class ObjDict(dict): + """ + reference: https://stackoverflow.com/a/32107024 + """ + def __init__(self, *args, **kwargs): + super(ObjDict, self).__init__(*args, **kwargs) + for arg in args: + if isinstance(arg, dict): + for k, v in arg.items(): + self[k] = v + if kwargs: + for k, v in kwargs.items(): + self[k] = v + + def __getattr__(self, attr): + return self.get(attr) + + def __setattr__(self, key, value): + self.__setitem__(key, value) + + def __setitem__(self, key, value): + super(ObjDict, self).__setitem__(key, value) + self.__dict__.update({key: value}) + + def __delattr__(self, item): + self.__delitem__(item) + + def __delitem__(self, key): + super(ObjDict, self).__delitem__(key) + del self.__dict__[key] \ No newline at end of file