Diff of /test.py [000000] .. [6d4adb]

Switch to unified view

a b/test.py
1
import torch
2
from Unet.model import UNET
3
from Utils import utils
4
import torchvision
5
from Preprocessing.preprocessing import Preprocessor
6
import cv2
7
import os
8
9
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10
11
model = UNET(in_channels=1, out_channels=1).to(DEVICE).float()
12
utils.load_checkpoint(torch.load('tmp/checkpoint.pth.tar', map_location=torch.device('cpu')), model)
13
model.eval()
14
15
prep = Preprocessor(128, 128)
16
17
for fileA in os.listdir('Dataset/zdim/valimagesbayes'):
18
    with torch.no_grad():
19
        x = cv2.imread('Dataset/zdim/valimagesbayes/'+fileA, 0)
20
        x = prep.preprocess(x)
21
        x = torch.from_numpy(x).float()
22
        x = x.unsqueeze(0).unsqueeze(0)
23
        x = x.to(DEVICE)
24
        preds = model(x)
25
        # preds[preds<0.5]=0
26
        # preds = torch.sigmoid(preds)
27
        torchvision.utils.save_image(preds, './resultsz/'+fileA)