--- a +++ b/test.py @@ -0,0 +1,27 @@ +import torch +from Unet.model import UNET +from Utils import utils +import torchvision +from Preprocessing.preprocessing import Preprocessor +import cv2 +import os + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +model = UNET(in_channels=1, out_channels=1).to(DEVICE).float() +utils.load_checkpoint(torch.load('tmp/checkpoint.pth.tar', map_location=torch.device('cpu')), model) +model.eval() + +prep = Preprocessor(128, 128) + +for fileA in os.listdir('Dataset/zdim/valimagesbayes'): + with torch.no_grad(): + x = cv2.imread('Dataset/zdim/valimagesbayes/'+fileA, 0) + x = prep.preprocess(x) + x = torch.from_numpy(x).float() + x = x.unsqueeze(0).unsqueeze(0) + x = x.to(DEVICE) + preds = model(x) + # preds[preds<0.5]=0 + # preds = torch.sigmoid(preds) + torchvision.utils.save_image(preds, './resultsz/'+fileA)