|
a |
|
b/U-Net/evaluate.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn.functional as F |
|
|
3 |
from tqdm import tqdm |
|
|
4 |
|
|
|
5 |
from utils.dice_score import multiclass_dice_coeff, dice_coeff |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
@torch.inference_mode() |
|
|
9 |
def evaluate(net, dataloader, device, amp): |
|
|
10 |
net.eval() |
|
|
11 |
num_val_batches = len(dataloader) |
|
|
12 |
dice_score = 0 |
|
|
13 |
|
|
|
14 |
# iterate over the validation set |
|
|
15 |
with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): |
|
|
16 |
for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False): |
|
|
17 |
image, mask_true = batch['image'], batch['mask'] |
|
|
18 |
|
|
|
19 |
# move images and labels to correct device and type |
|
|
20 |
image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) |
|
|
21 |
mask_true = mask_true.to(device=device, dtype=torch.long) |
|
|
22 |
|
|
|
23 |
# predict the mask |
|
|
24 |
mask_pred = net(image) |
|
|
25 |
|
|
|
26 |
if net.n_classes == 1: |
|
|
27 |
assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]' |
|
|
28 |
mask_pred = (F.sigmoid(mask_pred) > 0.5).float() |
|
|
29 |
# compute the Dice score |
|
|
30 |
dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False) |
|
|
31 |
else: |
|
|
32 |
assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes[' |
|
|
33 |
# convert to one-hot format |
|
|
34 |
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float() |
|
|
35 |
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() |
|
|
36 |
# compute the Dice score, ignoring background |
|
|
37 |
dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False) |
|
|
38 |
|
|
|
39 |
net.train() |
|
|
40 |
return dice_score / max(num_val_batches, 1) |