Diff of /eval.py [000000] .. [903821]

Switch to unified view

a b/eval.py
1
import os
2
import torch
3
from torch.utils.data import DataLoader
4
from torchvision import transforms
5
from loss import Loss,cal_dice
6
from dataloaders.la_heart import LAHeart, CenterCrop, ToTensor
7
from networks.vnet import VNet
8
9
10
def eval_loop(model, criterion, valid_loader, device):
11
    model.eval()
12
    running_loss = 0
13
    dice_valid = 0
14
15
    with torch.no_grad():
16
        for sampled_batch in valid_loader:
17
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
18
            volume_batch, label_batch = volume_batch.to(device), label_batch.to(device)
19
20
            outputs = model(volume_batch)
21
22
            loss = criterion(outputs, label_batch)
23
            dice = cal_dice(outputs, label_batch)
24
            print('dice: {:.3f}'.format(dice))
25
            running_loss += loss.item()
26
            dice_valid += dice.item()
27
28
    loss = running_loss / len(valid_loader)
29
    dice = dice_valid / len(valid_loader)
30
    return {'loss': loss, 'dice': dice}
31
32
33
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
34
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
data_path = '/***、data_set/LASet/data'
36
patch_size = (112,112,80)
37
model = VNet(n_channels=1,n_classes=2, normalization='batchnorm').to(device)
38
# 加载训练模型
39
weight_path = 'results/VNet.pth'
40
weight_dict = torch.load(weight_path, map_location=device)
41
model.load_state_dict(weight_dict)
42
print('Successfully loading checkpoint.')
43
criterion = Loss(n_classes=2).to(device)
44
db_test = LAHeart(base_dir=data_path,
45
                        split='test',
46
                        transform=transforms.Compose([
47
                        CenterCrop(patch_size),
48
                        ToTensor()
49
                    ]))
50
testloader = DataLoader(db_test,batch_size=1, num_workers=4, pin_memory=True)
51
52
model.eval()
53
valid_metrics = eval_loop(model, criterion, testloader, device)
54
# 这里的dice是测试集中心裁剪的dice
55
dice = valid_metrics['dice']
56
print('Average dice: {:.5f}'.format(dice))