Diff of /common/pytorch.py [000000] .. [bad60c]

Switch to unified view

a b/common/pytorch.py
1
import torch
2
3
4
def save_model(path, model):
5
    # Save a trained model
6
    print("** ** * Saving fine - tuned model ** ** * ")
7
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
8
    output_model_file = path
9
    torch.save(model_to_save.state_dict(), output_model_file)
10
11
12
def load_model(path, model):
13
    # load pretrained model and update weights
14
    pretrained_dict = torch.load(path, map_location='cpu')
15
    model_dict = model.state_dict()
16
    # 1. filter out unnecessary keys
17
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
18
    # 2. overwrite entries in the existing state dict
19
    model_dict.update(pretrained_dict)
20
    # 3. load the new state dict
21
    model.load_state_dict(model_dict)
22
    return model
23