Diff of /metrics.py [000000] .. [72db80]

Switch to unified view

a b/metrics.py
1
2
import numpy as np
3
def single_dice_coef(y_pred, y_true):
4
    # shape of y_true and y_pred: (height, width)
5
    intersection = np.sum(y_true * y_pred)
6
    if (np.sum(y_true) == 0) and (np.sum(y_pred) == 0):
7
        return 1
8
    return (2*intersection) / (np.sum(y_true) + np.sum(y_pred))
9
10
11
def mean_dice_coef(y_pred, y_true):
12
    # shape of y_true and y_pred: (n_samples, height, width)
13
    batch_size = y_true.shape[0]
14
    mean_dice_channel = 0.
15
    for i in range(batch_size):
16
        channel_dice = single_dice_coef(y_pred[i, :, :], y_true[i, :, :])
17
        mean_dice_channel += channel_dice/(batch_size)
18
    return mean_dice_channel
19
20
def mean_dice_coef_remove_empty(y_pred, y_true):
21
    # shape of y_true and y_pred: (n_samples, height, width)
22
    batch_size = y_true.shape[0]
23
    mean_dice_channel = 0.
24
    num_no_empty = batch_size
25
    for i in range(batch_size):
26
        if (np.sum(y_true[i, :, :]) == 0):
27
            num_no_empty -= 1
28
            continue
29
30
        channel_dice = single_dice_coef(y_pred[i, :, :], y_true[i, :, :])
31
        mean_dice_channel += channel_dice
32
    
33
    if num_no_empty == 0:
34
        return None
35
36
    return mean_dice_channel/(num_no_empty)