[96354c]: / tests / losses / test_dice_loss.py

Download this file

40 lines (26 with data), 1.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import pytest
import torch
from src.dataset.utils import nifi_volume
from src.losses import dice_loss, utils
from torch import nn
@pytest.fixture(scope="function")
def volume():
    """Load the segmentation volume of one BraTS 2020 training patient.

    The data directory defaults to the original author's local path and can be
    overridden with the ``BRATS_RANDOM_TUMOR_PATH`` environment variable. If the
    ``_seg.nii.gz`` file does not exist, the test is skipped instead of failing,
    because the file lives outside the repository (absolute local path).

    Returns:
        The result of ``nifi_volume.load_nifi_volume(path, normalize=False)``.
        NOTE(review): presumably a numpy array, since the consuming test calls
        ``.astype`` and ``torch.from_numpy`` on it — confirm.
    """
    patient = "BraTS20_Training_001_p0_64x64x64"
    gen_path = os.environ.get(
        "BRATS_RANDOM_TUMOR_PATH",
        "/Users/lauramora/Documents/MASTER/TFM/Data/2020/train/random_tumor_distribution/",
    )
    volume_path = os.path.join(gen_path, patient, f"{patient}_seg.nii.gz")
    # Data-dependent test: skip cleanly when the dataset is not available locally.
    if not os.path.isfile(volume_path):
        pytest.skip(f"segmentation volume not found: {volume_path}")
    return nifi_volume.load_nifi_volume(volume_path, normalize=False)
class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its argument unchanged.

    Assigned to ``DiceLoss.normalization`` in ``test_dice_loss`` so that the
    loss sees the prediction exactly as given.
    """

    def forward(self, input):
        # Hand back the very same object: no copy and no tensor operation.
        passthrough = input
        return passthrough
def test_dice_loss(volume):
    """A perfect one-hot prediction of the target gives loss 0 and score 1.

    Label 4 is remapped to 3 so the labels fall in ``range(classes)``.
    NOTE(review): this is presumably required by ``expand_as_one_hot``; confirm.
    The loss's ``normalization`` is replaced by ``Identity``, so the one-hot
    prediction is used as-is. Presumably this bypasses the sigmoid selected by
    ``sigmoid_normalization=True``.
    """
    volume[volume == 4] = 3
    classes = 4
    my_loss = dice_loss.DiceLoss(
        classes=classes,
        weight=None,
        sigmoid_normalization=True,
        eval_regions=False,
    )
    seg_mask = torch.from_numpy(volume.astype(int))
    target = seg_mask.unsqueeze(0).to("cpu")  # add a leading batch dimension
    # Renamed from ``input`` to avoid shadowing the builtin.
    prediction = utils.expand_as_one_hot(target.long(), classes)
    my_loss.normalization = Identity()
    loss, score, _ = my_loss(prediction, target)
    # The loss and score come out of floating-point arithmetic (possibly
    # including a smoothing term), so compare with a tolerance rather than
    # with exact equality.
    assert float(loss) == pytest.approx(0.0, abs=1e-6)
    assert float(score) == pytest.approx(1.0, abs=1e-6)