Diff of /Utils/utils.py [000000] .. [6d4adb]

Switch to unified view

a b/Utils/utils.py
1
import torch
2
import torchvision
3
import os
4
5
6
def save_checkpoint(state, filename='tmp/checkpoint.pth.tar'):
7
    print('[INFO] Saving checkpoint')
8
    torch.save(state, filename)
9
10
11
def load_checkpoint(checkpoint, model):
12
    print('[INFO] Loading checkpoint')
13
    model.load_state_dict(checkpoint['state_dict'])
14
15
16
def check_accuracy(loader, model, device):
17
    num_correct = 0
18
    num_pixels = 0
19
    dice_score = 0
20
21
    with torch.no_grad():
22
        for x, y in loader:
23
            x = x.to(device)
24
            y = y.to(device)
25
26
            preds = model(x)
27
            preds = (preds > 0.5).float()
28
29
            num_correct += (preds == y).sum()
30
            num_pixels += torch.numel(preds)
31
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-9)
32
            print('Got {}/{} with acc {:2f}'.format(num_correct, num_pixels, num_correct / num_pixels * 100))
33
            print('Dice score {}'.format(dice_score / len(loader)))
34
            # wandb.log({"dice": dice_score})
35
            # wandb.log({"acc": (num_correct, num_pixels, num_correct / num_pixels * 100)})
36
37
            model.train()
38
39
40
def save_predictions_as_imgs(loader, model, device, folder='tmp/'):
41
    model.eval()
42
43
    for idx, (x, y) in enumerate(loader):
44
        x = x.to(device)
45
46
        with torch.no_grad():
47
            preds = model(x)
48
            preds = (preds > 0.5).float()
49
50
        torchvision.utils.save_image(preds, os.path.join(folder, 'pred_{}.png'.format(idx)))
51
        # torchvision.utils.save_image(y, os.path.join(folder, '{}.png'.format(idx)))
52
53
    model.train()