[903821]: / eval.py

Download this file

57 lines (47 with data), 2.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from loss import Loss,cal_dice
from dataloaders.la_heart import LAHeart, CenterCrop, ToTensor
from networks.vnet import VNet
def eval_loop(model, criterion, valid_loader, device):
    """Evaluate ``model`` on ``valid_loader`` and return mean loss and dice.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    The dice of every batch (from ``cal_dice``) is printed as it is
    computed. The returned metrics are unweighted means over batches.

    Args:
        model: network, called as ``model(volume_batch)``.
        criterion: loss callable, invoked as ``criterion(outputs, labels)``.
        valid_loader: iterable yielding dicts with ``'image'`` and
            ``'label'`` entries (e.g. a ``DataLoader``). It does not need
            to implement ``__len__``.
        device: device the image and label tensors are moved to.

    Returns:
        dict with float entries ``'loss'`` and ``'dice'``.

    Raises:
        ValueError: if ``valid_loader`` yields no batches, since the
            averages would otherwise divide by zero.
    """
    model.eval()
    running_loss = 0.0
    dice_valid = 0.0
    n_batches = 0
    with torch.no_grad():
        for sampled_batch in valid_loader:
            volume_batch = sampled_batch['image'].to(device)
            label_batch = sampled_batch['label'].to(device)
            outputs = model(volume_batch)
            # float() accepts both 0-dim tensors and plain Python numbers,
            # so criterion / cal_dice may return either.
            loss_value = float(criterion(outputs, label_batch))
            dice_value = float(cal_dice(outputs, label_batch))
            print('dice: {:.3f}'.format(dice_value))
            running_loss += loss_value
            dice_valid += dice_value
            # Count batches ourselves instead of relying on len(valid_loader),
            # which iterable-style loaders may not provide.
            n_batches += 1
    if n_batches == 0:
        raise ValueError('valid_loader yielded no batches; cannot average metrics.')
    return {'loss': running_loss / n_batches, 'dice': dice_valid / n_batches}
def main():
    """Evaluate the trained VNet checkpoint on the LA heart test split.

    Loads weights from ``results/VNet.pth`` and builds a center-cropped
    test loader. Each batch's dice is printed, followed by the average
    dice over all test batches.
    """
    # Respect a GPU selection already made by the caller; default to GPU 0.
    # This must happen before CUDA is first initialized.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE: '***' is a redacted placeholder; set this to the real dataset root.
    data_path = '/***/data_set/LASet/data'
    patch_size = (112, 112, 80)

    model = VNet(n_channels=1, n_classes=2, normalization='batchnorm').to(device)
    # Load the trained model weights.
    weight_path = 'results/VNet.pth'
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (newer torch versions support weights_only=True for state_dicts).
    weight_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(weight_dict)
    print('Successfully loading checkpoint.')

    criterion = Loss(n_classes=2).to(device)
    db_test = LAHeart(
        base_dir=data_path,
        split='test',
        transform=transforms.Compose([
            CenterCrop(patch_size),
            ToTensor(),
        ]),
    )
    # Pinned memory only speeds up host->GPU copies; on CPU it merely warns.
    testloader = DataLoader(
        db_test,
        batch_size=1,
        num_workers=4,
        pin_memory=device.type == 'cuda',
    )

    # eval_loop switches the model to eval mode itself.
    valid_metrics = eval_loop(model, criterion, testloader, device)
    # This dice is measured on center-cropped test volumes (CenterCrop above).
    dice = valid_metrics['dice']
    print('Average dice: {:.5f}'.format(dice))


# The guard stops DataLoader worker processes (num_workers=4) from re-running
# the whole evaluation when the module is re-imported under the 'spawn'
# start method (the default on Windows and macOS).
if __name__ == '__main__':
    main()