Diff of /utils.py [000000] .. [c621c3]

Switch to unified view

a b/utils.py
1
import torch
2
import torchvision
3
from dataset import ChestDataset
4
from torch.utils.data import DataLoader
5
6
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
7
    print('=> Saving Checkpoint')
8
    torch.save(state, filename)
9
10
def load_checkpoint(checkpoint, model):
11
    print("=> Loading Checkpoint")
12
    model.load_state_dict(checkpoint['state_dict'])
13
14
def get_loaders(train_dir, train_maskdir, test_dir, test_maskdir, batch_size, train_transform, test_transform, num_workers=4, pin_memory=True):
15
    train_ds = ChestDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform)
16
    train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)
17
    test_ds = ChestDataset(image_dir=test_dir, mask_dir=test_maskdir, transform=test_transform)
18
    test_loader = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False)
19
20
    return train_loader, test_loader
21
22
def check_accuracy(loader, model):
23
    num_correct = 0
24
    num_pixels = 0
25
    dice_score = 0
26
    model.eval()
27
28
    with torch.no_grad():
29
        for x, y in loader:
30
            y = y.unsqueeze(1)
31
            preds = torch.sigmoid(model(x))
32
            preds = (preds> 0.5).float()
33
            num_correct += (preds == y).sum()
34
            num_pixels += torch.numel(preds)
35
            dice_score += (2*(preds*y).sum()) / ((preds+y).sum() + 1e-8)
36
    print(
37
        f"Got {num_correct}/ {num_pixels} with accuracy {num_correct/num_pixels*100:.2f}"
38
    )
39
    print(f'Dice Score: {dice_score/len(loader)}')
40
    model.train()
41
42
def save_predictions_as_images(loader, model, folder='saved_images/'):
43
    model.eval()
44
    for idx, (x,y) in enumerate(loader):
45
        with torch.no_grad():
46
            preds = torch.sigmoid(model(x))
47
            preds = (preds>0.5).float()
48
        torchvision.utils.save_image(
49
            preds, f'{folder}/pred_{idx}.png'
50
        )
51
        torchvision.utils.save_image(y.unsqueeze(1), f'{folder}{idx}.png')
52
53
    model.train()