# eval.py
1
import numpy as np
2
import pickle, nrrd, json
3
import torch
4
from torch.utils import data
5
import torch.optim as optim
6
import torch.nn as nn
7
from torch.nn import functional as F
8
9
import model
10
from data import *
11
12
# Load the run configuration (data paths under "path", model settings under
# "train3d") from config.json in the working directory. JSON text is UTF-8 by
# specification, so decode it explicitly instead of relying on the platform's
# default locale encoding.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)
# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the script can still execute (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Read the pickled list of labelled scan paths referenced by the config.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only point config["path"]["labelled_list"] at trusted data.
with open(config["path"]["labelled_list"], "rb") as fh:
    list_scans = pickle.load(fh)

# Keep the second '/'-separated component of every path (presumably the scan
# identifier) and skip the first 850 entries; the remaining IDs are the ones
# handed to the evaluation dataset below.
st_scans = [path.split('/')[1] for path in list_scans][850:]

# Echo the first scan of the evaluation subset.
print(st_scans[0])
# Evaluation dataset built from the held-out scan IDs. It is bound to
# `eval_set` rather than `dataset` so the `dataset` name used on the
# right-hand side (presumably provided by `from data import *`) is not
# shadowed, which would break re-running this script in the same interpreter.
eval_set = dataset.Dataset(
    st_scans, config["path"]["scans"], config["path"]["masks"], mode="3d"
)

criterion = utils.dice_loss
unet = model.UNet(
    1, config["train3d"]["n_classes"], config["train3d"]["start_filters"]
).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device can still be loaded here.
unet.load_state_dict(torch.load("./model", map_location=device))
# Put the network in inference mode (this changes the behaviour of layers such
# as dropout / batch-norm, if the model contains them).
unet.eval()

# Only the first scan of the evaluation set is scored and exported.
if len(eval_set) > 0:
    X, y = eval_set[0]
    # Prepend a batch axis. Values are rounded through float16 and then stored
    # in a float32 tensor (torch.Tensor's default dtype).
    # NOTE(review): confirm the float16 rounding mirrors training preprocessing.
    X = torch.Tensor(np.array([X.astype(np.float16)])).to(device)
    y = torch.Tensor(np.array([y.astype(np.float16)])).to(device)

    # Evaluation needs no autograd graph; skipping it saves substantial memory
    # for a full 3D volume.
    with torch.no_grad():
        logits = unet(X)
        loss = criterion(logits, y)
    print(loss.item())

    # NOTE(review): these are raw network outputs, not a thresholded/argmaxed
    # segmentation -- confirm that is what mask3D.nrrd is meant to contain.
    mask = logits.cpu().numpy()
    nrrd.write("mask3D.nrrd", mask[0][0])