|
a |
|
b/Utils/utils.py |
|
|
1 |
import torch |
|
|
2 |
import torchvision |
|
|
3 |
import os |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def save_checkpoint(state, filename='tmp/checkpoint.pth.tar'): |
|
|
7 |
print('[INFO] Saving checkpoint') |
|
|
8 |
torch.save(state, filename) |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
def load_checkpoint(checkpoint, model): |
|
|
12 |
print('[INFO] Loading checkpoint') |
|
|
13 |
model.load_state_dict(checkpoint['state_dict']) |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def check_accuracy(loader, model, device): |
|
|
17 |
num_correct = 0 |
|
|
18 |
num_pixels = 0 |
|
|
19 |
dice_score = 0 |
|
|
20 |
|
|
|
21 |
with torch.no_grad(): |
|
|
22 |
for x, y in loader: |
|
|
23 |
x = x.to(device) |
|
|
24 |
y = y.to(device) |
|
|
25 |
|
|
|
26 |
preds = model(x) |
|
|
27 |
preds = (preds > 0.5).float() |
|
|
28 |
|
|
|
29 |
num_correct += (preds == y).sum() |
|
|
30 |
num_pixels += torch.numel(preds) |
|
|
31 |
dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-9) |
|
|
32 |
print('Got {}/{} with acc {:2f}'.format(num_correct, num_pixels, num_correct / num_pixels * 100)) |
|
|
33 |
print('Dice score {}'.format(dice_score / len(loader))) |
|
|
34 |
# wandb.log({"dice": dice_score}) |
|
|
35 |
# wandb.log({"acc": (num_correct, num_pixels, num_correct / num_pixels * 100)}) |
|
|
36 |
|
|
|
37 |
model.train() |
|
|
38 |
|
|
|
39 |
|
|
|
40 |
def save_predictions_as_imgs(loader, model, device, folder='tmp/'): |
|
|
41 |
model.eval() |
|
|
42 |
|
|
|
43 |
for idx, (x, y) in enumerate(loader): |
|
|
44 |
x = x.to(device) |
|
|
45 |
|
|
|
46 |
with torch.no_grad(): |
|
|
47 |
preds = model(x) |
|
|
48 |
preds = (preds > 0.5).float() |
|
|
49 |
|
|
|
50 |
torchvision.utils.save_image(preds, os.path.join(folder, 'pred_{}.png'.format(idx))) |
|
|
51 |
# torchvision.utils.save_image(y, os.path.join(folder, '{}.png'.format(idx))) |
|
|
52 |
|
|
|
53 |
model.train() |