a b/U-Net/evaluate.py
1
import torch
2
import torch.nn.functional as F
3
from tqdm import tqdm
4
5
from utils.dice_score import multiclass_dice_coeff, dice_coeff
6
7
8
@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    """Compute the batch-averaged Dice score of ``net`` over a validation set.

    Args:
        net: Segmentation model. It must expose ``n_classes`` and the
            ``training`` / ``eval()`` / ``train(mode)`` API of
            ``torch.nn.Module``.
        dataloader: Iterable yielding dicts with ``'image'`` and ``'mask'``
            tensors; ``len(dataloader)`` must be defined.
        device: ``torch.device`` the batches are moved to.
        amp: If true, the forward pass runs under ``torch.autocast``.

    Returns:
        The sum of the per-batch values returned by ``dice_coeff`` /
        ``multiclass_dice_coeff``, divided by the number of batches. This is
        ``0.0`` for an empty dataloader. For multi-class models, channel 0
        (background) is excluded from the score.

    The model's previous train/eval mode is restored on exit, including when
    an exception is raised during validation.
    """
    was_training = net.training
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # On MPS, autocast is requested with device type 'cpu'
    # (presumably 'mps' autocast was unsupported in the targeted torch version).
    autocast_device = device.type if device.type != 'mps' else 'cpu'

    try:
        # iterate over the validation set
        with torch.autocast(autocast_device, enabled=amp):
            for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
                image, mask_true = batch['image'], batch['mask']

                # move images and labels to correct device and type
                image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                mask_true = mask_true.to(device=device, dtype=torch.long)

                # predict the mask
                mask_pred = net(image)

                if net.n_classes == 1:
                    assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                    # The network output has a channel dim (B, 1, H, W) while the
                    # target mask has none; drop it so both shapes line up.
                    if mask_pred.dim() == mask_true.dim() + 1:
                        mask_pred = mask_pred.squeeze(1)
                    mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
                    # compute the Dice score on float tensors, as in the multi-class branch
                    dice_score += dice_coeff(mask_pred, mask_true.float(), reduce_batch_first=False)
                else:
                    assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
                    # convert (B, H, W) index maps to (B, C, H, W) one-hot masks
                    mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                    mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                    # compute the Dice score, ignoring background (channel 0)
                    dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
    finally:
        # Restore whatever mode the caller had the model in, even on error.
        net.train(was_training)

    # max(..., 1) avoids division by zero for an empty dataloader
    return dice_score / max(num_val_batches, 1)