Diff of /utils.py [000000] .. [c621c3]

Switch to side-by-side view

--- a
+++ b/utils.py
@@ -0,0 +1,53 @@
+import torch
+import torchvision
+from dataset import ChestDataset
+from torch.utils.data import DataLoader
+
+def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
+    print('=> Saving Checkpoint')
+    torch.save(state, filename)
+
+def load_checkpoint(checkpoint, model):
+    print("=> Loading Checkpoint")
+    model.load_state_dict(checkpoint['state_dict'])
+
+def get_loaders(train_dir, train_maskdir, test_dir, test_maskdir, batch_size, train_transform, test_transform, num_workers=4, pin_memory=True):
+    train_ds = ChestDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform)
+    train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)
+    test_ds = ChestDataset(image_dir=test_dir, mask_dir=test_maskdir, transform=test_transform)
+    test_loader = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False)
+
+    return train_loader, test_loader
+
+def check_accuracy(loader, model):
+    num_correct = 0
+    num_pixels = 0
+    dice_score = 0
+    model.eval()
+
+    with torch.no_grad():
+        for x, y in loader:
+            y = y.unsqueeze(1)
+            preds = torch.sigmoid(model(x))
+            preds = (preds> 0.5).float()
+            num_correct += (preds == y).sum()
+            num_pixels += torch.numel(preds)
+            dice_score += (2*(preds*y).sum()) / ((preds+y).sum() + 1e-8)
+    print(
+        f"Got {num_correct}/ {num_pixels} with accuracy {num_correct/num_pixels*100:.2f}"
+    )
+    print(f'Dice Score: {dice_score/len(loader)}')
+    model.train()
+
+def save_predictions_as_images(loader, model, folder='saved_images/'):
+    model.eval()
+    for idx, (x,y) in enumerate(loader):
+        with torch.no_grad():
+            preds = torch.sigmoid(model(x))
+            preds = (preds>0.5).float()
+        torchvision.utils.save_image(
+            preds, f'{folder}/pred_{idx}.png'
+        )
+        torchvision.utils.save_image(y.unsqueeze(1), f'{folder}{idx}.png')
+
+    model.train()
\ No newline at end of file