|
a |
|
b/tests/test_loss.py |
|
|
1 |
from __future__ import division, print_function |
|
|
2 |
|
|
|
3 |
import unittest |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
from keras import backend as K |
|
|
7 |
|
|
|
8 |
from rvseg import loss |
|
|
9 |
|
|
|
10 |
class TestModel(unittest.TestCase): |
|
|
11 |
def sample_tensors(self): |
|
|
12 |
# have shapes (height, width, classes=2) |
|
|
13 |
y_true = K.constant([[[0, 1], [1, 0]], |
|
|
14 |
[[1, 0], [1, 0]]]) |
|
|
15 |
y_pred = K.constant([[[.4, .6], [.1, .9]], |
|
|
16 |
[[.5, .5], [.0, 1.]]]) |
|
|
17 |
return y_true, y_pred |
|
|
18 |
|
|
|
19 |
def test_sorensen_dice(self): |
|
|
20 |
y_true, y_pred = self.sample_tensors() |
|
|
21 |
dice_coefs = loss.soft_sorensen_dice(y_true, y_pred, axis=[0, 1]) |
|
|
22 |
dice_coefs = K.eval(dice_coefs) |
|
|
23 |
# class 1: (2 * 0.6 + 1) / (1 + 3 + 1) = 0.44 |
|
|
24 |
# class 2: (2 * 0.6 + 1) / (3 + 1 + 1) = 0.44 |
|
|
25 |
expected_dice_coefs = [0.44, 0.44] |
|
|
26 |
for x,y in zip(dice_coefs, expected_dice_coefs): |
|
|
27 |
self.assertAlmostEqual(x, y) |
|
|
28 |
|
|
|
29 |
dice_coefs = loss.hard_sorensen_dice(y_true, y_pred, axis=[0, 1]) |
|
|
30 |
dice_coefs = K.eval(dice_coefs) |
|
|
31 |
# class 1: (2 * 0 + 1) / (0 + 3 + 1) = 0.25 |
|
|
32 |
# class 2: (2 * 1 + 1) / (3 + 1 + 1) = 0.6 |
|
|
33 |
expected_dice_coefs = [0.25, 0.6] |
|
|
34 |
for x,y in zip(dice_coefs, expected_dice_coefs): |
|
|
35 |
self.assertAlmostEqual(x, y) |
|
|
36 |
|
|
|
37 |
def test_weighted_categorical_crossentropy(self): |
|
|
38 |
weights = [1, 9] |
|
|
39 |
y_true, y_pred = self.sample_tensors() |
|
|
40 |
|
|
|
41 |
lossfunc = loss.weighted_categorical_crossentropy( |
|
|
42 |
y_true, y_pred, weights) |
|
|
43 |
loss_val = K.eval(lossfunc) |
|
|
44 |
|
|
|
45 |
w = 2 * np.array(weights) / sum(weights) |
|
|
46 |
logs = -np.log(np.array([.1, .4, .5, .6, .9, 1e-8])) |
|
|
47 |
expected_loss_val = w[1]*logs[3]/4 + w[0]*(logs[0] + logs[2] + logs[5])/4 |
|
|
48 |
|
|
|
49 |
self.assertAlmostEqual(np.mean(expected_loss_val), loss_val, places=5) |
|
|
50 |
|