a b/test/test_metrics.py
1
from unittest import TestCase
2
3
import numpy as np
4
import keras.backend as K
5
6
7
from fetal_net.metrics import weighted_dice_coefficient
8
9
10
class TestWeightedDice(TestCase):
    """Tests for ``fetal_net.metrics.weighted_dice_coefficient``.

    Each test builds a label volume of shape (3, 5, 5, 5) whose three leading
    slices ("channels") hold blocks of ones of different sizes. The
    expectations below assume the metric scores each leading slice separately
    and averages the results -- NOTE(review): this is what the tests encode;
    confirm against the metric's implementation.
    """

    # Absolute tolerance for every metric comparison. The metric is evaluated
    # through Keras backend variables, so exact float equality is not reliable.
    DELTA = 0.00001

    @staticmethod
    def _make_labels(include_first_channel=True):
        """Return the shared (3, 5, 5, 5) float label volume.

        :param include_first_channel: when False, channel 0 is left entirely
            empty (all zeros); channels 1 and 2 are always filled.
        """
        data = np.zeros((3, 5, 5, 5))
        if include_first_channel:
            data[0, 0:1] = 1
        data[1, 0:2] = 1
        data[2, 1:4] = 1
        return data

    @staticmethod
    def _dice(y_true, y_pred):
        """Evaluate the metric on two numpy arrays and return the numeric result."""
        return K.eval(weighted_dice_coefficient(K.variable(y_true), K.variable(y_pred)))

    def test_weighted_dice_coefficient(self):
        """Blanking one predicted channel drops the score to 2/3 of the maximum."""
        data = self._make_labels()
        max_dice = self._dice(data, data)
        for index in range(data.shape[0]):
            temp_data = np.copy(data)
            # Wipe a single channel of the prediction; the other two stay perfect.
            temp_data[index] = 0
            dice = self._dice(data, temp_data)
            self.assertAlmostEqual(dice, (2 * max_dice) / 3, delta=self.DELTA,
                                   msg="channel %d" % index)

    def test_blank_dice_coefficient(self):
        """An all-zero prediction against non-empty labels scores ~0."""
        data = self._make_labels()
        blank = np.zeros_like(data)
        self.assertAlmostEqual(self._dice(data, blank), 0, delta=self.DELTA)

    def test_empty_label(self):
        """An empty label channel that is also predicted empty counts as a perfect match."""
        data = self._make_labels(include_first_channel=False)
        # Compare with a tolerance, like the other tests; the metric returns a
        # float computed by the backend, so exact equality with 1 is fragile.
        self.assertAlmostEqual(self._dice(data, data), 1, delta=self.DELTA)