
medseg_dl/model/metrics.py
import tensorflow as tf


def metrics_fn(labels, probs, channels_out, b_verbose=False):
    """ build metric ops: reset op, grouped update op and current-value tensors """

    labels_var = tf.argmax(labels, axis=-1)
    labels_hot = tf.one_hot(labels_var, channels_out)
    predictions_var = tf.argmax(probs, axis=-1)
    predictions_hot = tf.one_hot(predictions_var, channels_out)

    if b_verbose:
        # tf.Print expects a tensor as its first argument, so attach the debug print
        # to labels_hot instead of the metrics dict
        labels_hot = tf.Print(labels_hot, [tf.shape(labels_hot), tf.shape(predictions_hot)],
                              'fetched metrics with labels/preds: ', summarize=20)

    metrics = fetch_metrics(labels_hot, predictions_hot, labels_var, predictions_var, channels_out)

    # value tensors of the metrics (read them to report the running averages)
    metrics_values = {k: v[0] for k, v in metrics.items()}

    # group the update ops of the tf.metrics so a single op updates them all
    update_metrics_op = tf.group(*[op for _, op in metrics.values()])

    # op to reset the local variables used by tf.metrics when an epoch restarts
    scope_full = tf.get_default_graph().get_name_scope()
    metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope_full)
    metrics_init_op = tf.variables_initializer(metric_variables)

    return metrics_init_op, update_metrics_op, metrics_values


def fetch_metrics(labels_hot, predictions_hot, labels_var, predictions_var, channels_out):
    """ assemble the tf.metrics used for evaluation """

    dice_scores = calc_dice_scores(labels_hot, predictions_hot)
    single_dice_scores = tf.split(dice_scores, channels_out, axis=-1)

    # per-class presence masks so the means only include classes present in the labels
    mean_masks = tf.clip_by_value(tf.count_nonzero(labels_hot, axis=[1, 2, 3]), 0, 1)
    masks = tf.split(mean_masks, channels_out, axis=-1)
    mean_dice_scores = tf.reduce_mean(tf.boolean_mask(dice_scores, tf.cast(mean_masks, dtype=tf.bool)), axis=-1)

    # averaged metrics
    metrics = {
            'accuracy': tf.metrics.accuracy(labels_var, predictions_var),
            'mean_pc_acc': tf.metrics.mean_per_class_accuracy(labels_var, predictions_var, channels_out),
            'mean_iou': tf.metrics.mean_iou(labels_var, predictions_var, channels_out),
            'mean_dice': tf.metrics.mean(mean_dice_scores)}

    # single-class dice metrics, weighted so absent classes do not dilute the running mean
    for idx_ch in range(channels_out):
        metrics[f'dice_c{idx_ch}'] = tf.metrics.mean(single_dice_scores[idx_ch], weights=masks[idx_ch])

    return metrics


def calc_dice_scores(labels_hot, predictions_hot, eps=1e-12):
    """ hard dice score per sample and class from one-hot tensors """

    numerator = 2 * tf.reduce_sum(tf.cast(tf.logical_and(tf.cast(labels_hot, dtype=tf.bool), tf.cast(predictions_hot, dtype=tf.bool)), dtype=tf.float32), axis=(1, 2, 3))
    denominator = tf.reduce_sum(labels_hot + predictions_hot, axis=(1, 2, 3))
    scores = tf.divide(numerator, denominator + eps)

    return scores
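
A minimal usage sketch, assuming a typical TF1 session-based evaluation loop; the names labels_ph, probs_op, num_eval_batches and next_eval_batch are illustrative placeholders, not part of this module:

# usage sketch (assumption: TF1 graph/session workflow; the placeholder names below
# are illustrative and not defined in this repository)
metrics_init_op, update_metrics_op, metrics_values = metrics_fn(labels_ph, probs_op, channels_out=3)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(metrics_init_op)                # reset the metric accumulators for this epoch
    for _ in range(num_eval_batches):
        sess.run(update_metrics_op, feed_dict=next_eval_batch())
    results = sess.run(metrics_values)       # e.g. results['mean_dice'], results['dice_c1']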