[fbbdf8]: / utils / network_utils.py

Download this file

18 lines (10 with data), 377 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import torch
def load_checkpoint(path, model, optimizer=None, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The loaded checkpoint is expected to be a mapping containing the keys
    ``"state_dict"``, ``"epoch"`` and ``"total_iter"``, plus ``"optimizer"``
    when an optimizer is supplied.

    Args:
        path: Location passed straight to ``torch.load``.
        model: Object whose ``load_state_dict`` receives ``checkpoint["state_dict"]``.
        optimizer: Optional object whose ``load_state_dict`` receives
            ``checkpoint["optimizer"]``; skipped when ``None``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a checkpoint that was saved from CUDA tensors on a machine without
            a GPU. ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        Tuple ``(epoch, total_iter)`` read from the checkpoint.

    Raises:
        KeyError: If the loaded mapping lacks one of the required keys.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    # Explicit None check instead of relying on the optimizer object's truthiness.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    print(f"Checkpoint {path} successfully loaded")
    return checkpoint["epoch"], checkpoint["total_iter"]
def save_checkpoint(state, path):
    """Serialize ``state`` to ``path`` using ``torch.save``.

    Args:
        state: Object to persist (``load_checkpoint`` in this module reads
            ``"state_dict"``, ``"optimizer"``, ``"epoch"`` and ``"total_iter"``
            from what it loads).
        path: Destination passed straight to ``torch.save``.
    """
    destination = path
    torch.save(obj=state, f=destination)