"""eval.py -- run the saved UNet on held-out labelled scans.

Reads ``config.json``, loads the pickled scan list it points to, keeps the
entries from index 850 onward, evaluates the model stored at ``./model`` on
the first of them, prints the loss from ``utils.dice_loss`` and writes the
raw network output for that scan to ``mask3D.nrrd``.
"""
import numpy as np
import pickle
import nrrd
import json
import torch
from torch.utils import data
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F

import model
# NOTE: kept last so its names resolve exactly as before; ``dataset`` and
# ``utils`` used below presumably come from here -- confirm against data/.
from data import *

with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): pickle.load can execute arbitrary code from a crafted file;
# only point config["path"]["labelled_list"] at trusted files.
with open(config["path"]["labelled_list"], "rb") as f:
    list_scans = pickle.load(f)

# Keep the second '/'-separated component of each listed path.
st_scans = [s.split('/')[1] for s in list_scans]
# Drop the first 850 entries; presumably those were the training split --
# confirm against the training script.
st_scans = st_scans[850:]
if not st_scans:
    raise SystemExit(
        f"labelled list has {len(list_scans)} entries; need more than 850 "
        "to have any scans left to evaluate"
    )
print(st_scans[0])

# Separate name so the ``dataset`` module is not shadowed by the instance.
eval_dataset = dataset.Dataset(
    st_scans, config["path"]["scans"], config["path"]["masks"], mode="3d"
)
if len(eval_dataset) == 0:
    raise SystemExit("dataset is empty; nothing to evaluate")

criterion = utils.dice_loss
# First argument presumably the number of input channels -- confirm in model.UNet.
unet = model.UNet(
    1, config["train3d"]["n_classes"], config["train3d"]["start_filters"]
).to(device)
# map_location lets a checkpoint saved on a GPU load on the selected device.
unet.load_state_dict(torch.load("./model", map_location=device))
# Inference mode: BatchNorm/Dropout (if present) must not behave as in training.
unet.eval()

# Only the first held-out scan is evaluated (the original loop broke after
# its first iteration); no_grad avoids building an autograd graph.
with torch.no_grad():
    X, y = eval_dataset[0]
    # Wrap in a list to add a leading batch dimension. Values are rounded to
    # float16 before torch.Tensor converts them to float32 -- presumably to
    # mirror training-time preprocessing; confirm.
    X = torch.Tensor(np.array([X.astype(np.float16)])).to(device)
    y = torch.Tensor(np.array([y.astype(np.float16)])).to(device)
    logits = unet(X)
    loss = criterion(logits, y)
    print(loss.item())
    # Raw network output: no activation or thresholding is applied here.
    mask = logits.cpu().numpy()
    # [0][0]: first batch element, first output channel (channel meaning
    # assumed -- confirm against n_classes).
    nrrd.write("mask3D.nrrd", mask[0][0])