|
a |
|
b/eval.py |
|
|
1 |
import os |
|
|
2 |
import torch |
|
|
3 |
from torch.utils.data import DataLoader |
|
|
4 |
from torchvision import transforms |
|
|
5 |
from loss import Loss,cal_dice |
|
|
6 |
from dataloaders.la_heart import LAHeart, CenterCrop, ToTensor |
|
|
7 |
from networks.vnet import VNet |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def eval_loop(model, criterion, valid_loader, device): |
|
|
11 |
model.eval() |
|
|
12 |
running_loss = 0 |
|
|
13 |
dice_valid = 0 |
|
|
14 |
|
|
|
15 |
with torch.no_grad(): |
|
|
16 |
for sampled_batch in valid_loader: |
|
|
17 |
volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] |
|
|
18 |
volume_batch, label_batch = volume_batch.to(device), label_batch.to(device) |
|
|
19 |
|
|
|
20 |
outputs = model(volume_batch) |
|
|
21 |
|
|
|
22 |
loss = criterion(outputs, label_batch) |
|
|
23 |
dice = cal_dice(outputs, label_batch) |
|
|
24 |
print('dice: {:.3f}'.format(dice)) |
|
|
25 |
running_loss += loss.item() |
|
|
26 |
dice_valid += dice.item() |
|
|
27 |
|
|
|
28 |
loss = running_loss / len(valid_loader) |
|
|
29 |
dice = dice_valid / len(valid_loader) |
|
|
30 |
return {'loss': loss, 'dice': dice} |
|
|
31 |
|
|
|
32 |
|
|
|
33 |
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
34 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
35 |
data_path = '/***、data_set/LASet/data' |
|
|
36 |
patch_size = (112,112,80) |
|
|
37 |
model = VNet(n_channels=1,n_classes=2, normalization='batchnorm').to(device) |
|
|
38 |
# 加载训练模型 |
|
|
39 |
weight_path = 'results/VNet.pth' |
|
|
40 |
weight_dict = torch.load(weight_path, map_location=device) |
|
|
41 |
model.load_state_dict(weight_dict) |
|
|
42 |
print('Successfully loading checkpoint.') |
|
|
43 |
criterion = Loss(n_classes=2).to(device) |
|
|
44 |
db_test = LAHeart(base_dir=data_path, |
|
|
45 |
split='test', |
|
|
46 |
transform=transforms.Compose([ |
|
|
47 |
CenterCrop(patch_size), |
|
|
48 |
ToTensor() |
|
|
49 |
])) |
|
|
50 |
testloader = DataLoader(db_test,batch_size=1, num_workers=4, pin_memory=True) |
|
|
51 |
|
|
|
52 |
model.eval() |
|
|
53 |
valid_metrics = eval_loop(model, criterion, testloader, device) |
|
|
54 |
# 这里的dice是测试集中心裁剪的dice |
|
|
55 |
dice = valid_metrics['dice'] |
|
|
56 |
print('Average dice: {:.5f}'.format(dice)) |