[6d4adb]: / Utils / utils.py

Download this file

54 lines (37 with data), 1.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torchvision
import os
def save_checkpoint(state, filename='tmp/checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: Any object ``torch.save`` can serialize.
        filename: Destination path, or a file-like object accepted by
            ``torch.save``. When it is a path, missing parent directories
            are created first.
    """
    print('[INFO] Saving checkpoint')
    # torch.save does not create directories, so the default 'tmp/...'
    # target would otherwise fail in a fresh working directory.
    # File-like objects are passed through untouched.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)
def load_checkpoint(checkpoint, model):
    """Restore ``model``'s parameters from a checkpoint mapping.

    The weights are read from ``checkpoint['state_dict']`` and handed to
    ``model.load_state_dict``. Nothing is returned.
    """
    print('[INFO] Loading checkpoint')
    weights = checkpoint['state_dict']
    model.load_state_dict(weights)
def check_accuracy(loader, model, device):
    """Print element-wise accuracy and mean Dice score of ``model`` on ``loader``.

    Each batch's model output is thresholded at 0.5 and compared against
    the target. Accuracy is pooled over all elements; the Dice score is
    computed per batch and averaged over batches.

    Args:
        loader: Iterable yielding ``(x, y)`` batches of tensors.
        model: ``torch.nn.Module`` to evaluate. It is switched to eval
            mode for the pass and put back in training mode afterwards,
            even if evaluation raises.
        device: Device that inputs and targets are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    num_batches = 0

    # Evaluate in eval mode so layers such as Dropout/BatchNorm use their
    # inference behaviour instead of training behaviour.
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                # NOTE(review): raw outputs are thresholded at 0.5, which
                # presumes the model already emits probabilities — confirm.
                preds = (model(x) > 0.5).float()
                # .item() keeps the running totals as plain Python numbers.
                num_correct += (preds == y).sum().item()
                num_pixels += torch.numel(preds)
                # The small epsilon avoids 0/0 when both masks are empty.
                dice_score += (
                    (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-9)
                ).item()
                num_batches += 1
    finally:
        model.train()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
    mean_dice = dice_score / num_batches if num_batches else 0.0
    print('Got {}/{} with acc {:.2f}'.format(num_correct, num_pixels, accuracy))
    print('Dice score {}'.format(mean_dice))
    # wandb.log({"dice": dice_score})
    # wandb.log({"acc": (num_correct, num_pixels, num_correct / num_pixels * 100)})
def save_predictions_as_imgs(loader, model, device, folder='tmp/'):
    """Write the thresholded predictions for every batch as PNG images.

    For batch number ``idx`` the model output is thresholded at 0.5 and
    saved to ``<folder>/pred_<idx>.png`` via
    ``torchvision.utils.save_image``.

    Args:
        loader: Iterable yielding ``(x, y)`` batches; only ``x`` is used.
        model: ``torch.nn.Module`` producing the predictions. It is
            switched to eval mode for the pass and put back in training
            mode afterwards, even if saving raises.
        device: Device that inputs are moved to.
        folder: Output directory; created if it does not exist.
    """
    # save_image does not create directories, so make sure the target exists.
    os.makedirs(folder, exist_ok=True)

    model.eval()
    try:
        for idx, (x, _) in enumerate(loader):
            x = x.to(device)
            with torch.no_grad():
                # NOTE(review): raw outputs are thresholded at 0.5, which
                # presumes the model already emits probabilities — confirm.
                preds = (model(x) > 0.5).float()
            torchvision.utils.save_image(preds, os.path.join(folder, 'pred_{}.png'.format(idx)))
            # torchvision.utils.save_image(y, os.path.join(folder, '{}.png'.format(idx)))
    finally:
        model.train()