--- a +++ b/utils.py @@ -0,0 +1,53 @@ +import torch +import torchvision +from dataset import ChestDataset +from torch.utils.data import DataLoader + +def save_checkpoint(state, filename="my_checkpoint.pth.tar"): + print('=> Saving Checkpoint') + torch.save(state, filename) + +def load_checkpoint(checkpoint, model): + print("=> Loading Checkpoint") + model.load_state_dict(checkpoint['state_dict']) + +def get_loaders(train_dir, train_maskdir, test_dir, test_maskdir, batch_size, train_transform, test_transform, num_workers=4, pin_memory=True): + train_ds = ChestDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform) + train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True) + test_ds = ChestDataset(image_dir=test_dir, mask_dir=test_maskdir, transform=test_transform) + test_loader = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False) + + return train_loader, test_loader + +def check_accuracy(loader, model): + num_correct = 0 + num_pixels = 0 + dice_score = 0 + model.eval() + + with torch.no_grad(): + for x, y in loader: + y = y.unsqueeze(1) + preds = torch.sigmoid(model(x)) + preds = (preds> 0.5).float() + num_correct += (preds == y).sum() + num_pixels += torch.numel(preds) + dice_score += (2*(preds*y).sum()) / ((preds+y).sum() + 1e-8) + print( + f"Got {num_correct}/ {num_pixels} with accuracy {num_correct/num_pixels*100:.2f}" + ) + print(f'Dice Score: {dice_score/len(loader)}') + model.train() + +def save_predictions_as_images(loader, model, folder='saved_images/'): + model.eval() + for idx, (x,y) in enumerate(loader): + with torch.no_grad(): + preds = torch.sigmoid(model(x)) + preds = (preds>0.5).float() + torchvision.utils.save_image( + preds, f'{folder}/pred_{idx}.png' + ) + torchvision.utils.save_image(y.unsqueeze(1), f'{folder}{idx}.png') + + model.train() \ No newline at end of file