--- a +++ b/eval.py @@ -0,0 +1,56 @@ +import os +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from loss import Loss,cal_dice +from dataloaders.la_heart import LAHeart, CenterCrop, ToTensor +from networks.vnet import VNet + + +def eval_loop(model, criterion, valid_loader, device): + model.eval() + running_loss = 0 + dice_valid = 0 + + with torch.no_grad(): + for sampled_batch in valid_loader: + volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] + volume_batch, label_batch = volume_batch.to(device), label_batch.to(device) + + outputs = model(volume_batch) + + loss = criterion(outputs, label_batch) + dice = cal_dice(outputs, label_batch) + print('dice: {:.3f}'.format(dice)) + running_loss += loss.item() + dice_valid += dice.item() + + loss = running_loss / len(valid_loader) + dice = dice_valid / len(valid_loader) + return {'loss': loss, 'dice': dice} + + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +data_path = '/***、data_set/LASet/data' +patch_size = (112,112,80) +model = VNet(n_channels=1,n_classes=2, normalization='batchnorm').to(device) +# 加载训练模型 +weight_path = 'results/VNet.pth' +weight_dict = torch.load(weight_path, map_location=device) +model.load_state_dict(weight_dict) +print('Successfully loading checkpoint.') +criterion = Loss(n_classes=2).to(device) +db_test = LAHeart(base_dir=data_path, + split='test', + transform=transforms.Compose([ + CenterCrop(patch_size), + ToTensor() + ])) +testloader = DataLoader(db_test,batch_size=1, num_workers=4, pin_memory=True) + +model.eval() +valid_metrics = eval_loop(model, criterion, testloader, device) +# 这里的dice是测试集中心裁剪的dice +dice = valid_metrics['dice'] +print('Average dice: {:.5f}'.format(dice))