a | b/utils/network_utils.py | ||
---|---|---|---|
1 | import torch |
||
2 | |||
3 | |||
4 | def load_checkpoint(path, model, optimizer=None): |
||
5 | pth = torch.load(path) |
||
6 | |||
7 | model.load_state_dict(pth["state_dict"]) |
||
8 | if optimizer: |
||
9 | optimizer.load_state_dict(pth["optimizer"]) |
||
10 | |||
11 | print("Checkpoint {} successfully loaded".format(path)) |
||
12 | |||
13 | return pth["epoch"], pth["total_iter"] |
||
14 | |||
15 | |||
16 | def save_checkpoint(state, path): |
||
17 | torch.save(state, path) |