[bad60c]: / common / pytorch.py

Download this file

24 lines (18 with data), 797 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
def save_model(path, model):
    """Write the parameters and buffers of ``model`` to ``path`` as a state dict.

    When ``model`` exposes a ``module`` attribute (as wrappers such as
    ``torch.nn.DataParallel`` do), the inner ``model.module`` is saved instead
    of the wrapper, so the checkpoint holds only the wrapped network's keys.

    Args:
        path: destination handed unchanged to ``torch.save``.
        model: the module, or a wrapper around one, to serialize.
    """
    print("** ** * Saving fine - tuned model ** ** * ")
    # Unwrap a `.module`-style wrapper, falling back to the model itself.
    inner = getattr(model, 'module', model)
    torch.save(inner.state_dict(), path)
def load_model(path, model):
    """Load weights from the checkpoint at ``path`` into ``model``, leniently.

    The checkpoint is read with ``torch.load(path, map_location='cpu')`` and
    merged into ``model``'s current state dict:

    * checkpoint entries whose key matches a model key overwrite that entry;
    * checkpoint entries with no matching key are ignored;
    * model entries absent from the checkpoint keep their current values.

    Keys are matched exactly first. A checkpoint key with no exact match is
    retried with a leading ``"module."`` prefix added or removed. This lets a
    checkpoint written by ``save_model`` load into a wrapper that exposes
    ``.module`` (e.g. ``torch.nn.DataParallel``), whose keys are prefixed. It
    also lets a prefixed checkpoint load into an unwrapped model. If both forms
    are present, the exact match wins.

    Args:
        path: checkpoint location, passed unchanged to ``torch.load``.
        model: module (or wrapper) to update in place.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        Whatever ``model.load_state_dict`` raises, e.g. when a matched entry's
        tensor shape differs from the model's.
    """
    prefix = "module."
    pretrained_dict = torch.load(path, map_location='cpu')
    model_dict = model.state_dict()

    matched = {}
    for key, value in pretrained_dict.items():
        if key in model_dict:
            # Exact match: plain assignment so it overrides any earlier
            # prefix-adjusted entry that landed on the same key.
            matched[key] = value
            continue
        # Retry with the wrapper prefix toggled (added if missing, removed if present).
        alt = key[len(prefix):] if key.startswith(prefix) else prefix + key
        if alt in model_dict:
            # setdefault: never displace an exact match already recorded.
            matched.setdefault(alt, value)

    # Overwrite only the matched entries; everything else keeps the model's values,
    # so the merged dict is complete and strict loading succeeds.
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model