Diff of /utils/network_utils.py [000000] .. [fbbdf8]

Switch to unified view

a b/utils/network_utils.py
1
import torch
2
3
4
def load_checkpoint(path, model, optimizer=None):
5
    pth = torch.load(path)
6
7
    model.load_state_dict(pth["state_dict"])
8
    if optimizer:
9
        optimizer.load_state_dict(pth["optimizer"])
10
11
    print("Checkpoint {} successfully loaded".format(path))
12
13
    return pth["epoch"], pth["total_iter"]
14
15
16
def save_checkpoint(state, path):
17
    torch.save(state, path)