|
a |
|
b/test.py |
|
|
1 |
import torch |
|
|
2 |
from Unet.model import UNET |
|
|
3 |
from Utils import utils |
|
|
4 |
import torchvision |
|
|
5 |
from Preprocessing.preprocessing import Preprocessor |
|
|
6 |
import cv2 |
|
|
7 |
import os |
|
|
8 |
|
|
|
9 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
10 |
|
|
|
11 |
model = UNET(in_channels=1, out_channels=1).to(DEVICE).float() |
|
|
12 |
utils.load_checkpoint(torch.load('tmp/checkpoint.pth.tar', map_location=torch.device('cpu')), model) |
|
|
13 |
model.eval() |
|
|
14 |
|
|
|
15 |
prep = Preprocessor(128, 128) |
|
|
16 |
|
|
|
17 |
for fileA in os.listdir('Dataset/zdim/valimagesbayes'): |
|
|
18 |
with torch.no_grad(): |
|
|
19 |
x = cv2.imread('Dataset/zdim/valimagesbayes/'+fileA, 0) |
|
|
20 |
x = prep.preprocess(x) |
|
|
21 |
x = torch.from_numpy(x).float() |
|
|
22 |
x = x.unsqueeze(0).unsqueeze(0) |
|
|
23 |
x = x.to(DEVICE) |
|
|
24 |
preds = model(x) |
|
|
25 |
# preds[preds<0.5]=0 |
|
|
26 |
# preds = torch.sigmoid(preds) |
|
|
27 |
torchvision.utils.save_image(preds, './resultsz/'+fileA) |