|
a |
|
b/common/pytorch.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
def save_model(path, model): |
|
|
5 |
# Save a trained model |
|
|
6 |
print("** ** * Saving fine - tuned model ** ** * ") |
|
|
7 |
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self |
|
|
8 |
output_model_file = path |
|
|
9 |
torch.save(model_to_save.state_dict(), output_model_file) |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
def load_model(path, model): |
|
|
13 |
# load pretrained model and update weights |
|
|
14 |
pretrained_dict = torch.load(path, map_location='cpu') |
|
|
15 |
model_dict = model.state_dict() |
|
|
16 |
# 1. filter out unnecessary keys |
|
|
17 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} |
|
|
18 |
# 2. overwrite entries in the existing state dict |
|
|
19 |
model_dict.update(pretrained_dict) |
|
|
20 |
# 3. load the new state dict |
|
|
21 |
model.load_state_dict(model_dict) |
|
|
22 |
return model |
|
|
23 |
|