|
a |
|
b/utils.py |
|
|
1 |
import torch |
|
|
2 |
import torchvision |
|
|
3 |
from dataset import ChestDataset |
|
|
4 |
from torch.utils.data import DataLoader |
|
|
5 |
|
|
|
6 |
def save_checkpoint(state, filename="my_checkpoint.pth.tar"): |
|
|
7 |
print('=> Saving Checkpoint') |
|
|
8 |
torch.save(state, filename) |
|
|
9 |
|
|
|
10 |
def load_checkpoint(checkpoint, model): |
|
|
11 |
print("=> Loading Checkpoint") |
|
|
12 |
model.load_state_dict(checkpoint['state_dict']) |
|
|
13 |
|
|
|
14 |
def get_loaders(train_dir, train_maskdir, test_dir, test_maskdir, batch_size, train_transform, test_transform, num_workers=4, pin_memory=True): |
|
|
15 |
train_ds = ChestDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform) |
|
|
16 |
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True) |
|
|
17 |
test_ds = ChestDataset(image_dir=test_dir, mask_dir=test_maskdir, transform=test_transform) |
|
|
18 |
test_loader = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False) |
|
|
19 |
|
|
|
20 |
return train_loader, test_loader |
|
|
21 |
|
|
|
22 |
def check_accuracy(loader, model): |
|
|
23 |
num_correct = 0 |
|
|
24 |
num_pixels = 0 |
|
|
25 |
dice_score = 0 |
|
|
26 |
model.eval() |
|
|
27 |
|
|
|
28 |
with torch.no_grad(): |
|
|
29 |
for x, y in loader: |
|
|
30 |
y = y.unsqueeze(1) |
|
|
31 |
preds = torch.sigmoid(model(x)) |
|
|
32 |
preds = (preds> 0.5).float() |
|
|
33 |
num_correct += (preds == y).sum() |
|
|
34 |
num_pixels += torch.numel(preds) |
|
|
35 |
dice_score += (2*(preds*y).sum()) / ((preds+y).sum() + 1e-8) |
|
|
36 |
print( |
|
|
37 |
f"Got {num_correct}/ {num_pixels} with accuracy {num_correct/num_pixels*100:.2f}" |
|
|
38 |
) |
|
|
39 |
print(f'Dice Score: {dice_score/len(loader)}') |
|
|
40 |
model.train() |
|
|
41 |
|
|
|
42 |
def save_predictions_as_images(loader, model, folder='saved_images/'): |
|
|
43 |
model.eval() |
|
|
44 |
for idx, (x,y) in enumerate(loader): |
|
|
45 |
with torch.no_grad(): |
|
|
46 |
preds = torch.sigmoid(model(x)) |
|
|
47 |
preds = (preds>0.5).float() |
|
|
48 |
torchvision.utils.save_image( |
|
|
49 |
preds, f'{folder}/pred_{idx}.png' |
|
|
50 |
) |
|
|
51 |
torchvision.utils.save_image(y.unsqueeze(1), f'{folder}{idx}.png') |
|
|
52 |
|
|
|
53 |
model.train() |