|
a |
|
b/medseg_dl/model/metrics.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import os |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def metrics_fn(labels, probs, channels_out, b_verbose=False): |
|
|
6 |
""" calc metrics and fetch ops """ |
|
|
7 |
|
|
|
8 |
labels_var = tf.argmax(labels, axis=-1) |
|
|
9 |
labels_hot = tf.one_hot(labels_var, channels_out) |
|
|
10 |
predictions_var = tf.argmax(probs, axis=-1) |
|
|
11 |
predictions_hot = tf.one_hot(predictions_var, channels_out) |
|
|
12 |
metrics = fetch_metrics(labels_hot, predictions_hot, labels_var, predictions_var, channels_out) |
|
|
13 |
|
|
|
14 |
if b_verbose: |
|
|
15 |
metrics = tf.Print(metrics, [tf.shape(labels_hot), tf.shape(predictions_hot)], 'fetched metrics with labels/preds: ', summarize=20) |
|
|
16 |
|
|
|
17 |
# Get the values of the metrics (used for update later) |
|
|
18 |
metrics_values = {k: v[0] for k, v in metrics.items()} |
|
|
19 |
|
|
|
20 |
# Group the update ops for the tf.metrics, so that we can run only one op to update them all |
|
|
21 |
update_metrics_op = tf.group(*[op for _, op in metrics.values()]) |
|
|
22 |
|
|
|
23 |
# Get the op to reset the local variables used in tf.metrics, for when we restart an epoch |
|
|
24 |
scope_full = os.path.join(tf.get_default_graph().get_name_scope()) |
|
|
25 |
metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope_full) |
|
|
26 |
metrics_init_op = tf.variables_initializer(metric_variables) |
|
|
27 |
|
|
|
28 |
return metrics_init_op, update_metrics_op, metrics_values |
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def fetch_metrics(labels_hot, predictions_hot, labels_var, predictions_var, channels_out): |
|
|
32 |
""" metrics used for calculation """ |
|
|
33 |
|
|
|
34 |
dice_scores = calc_dice_scores(labels_hot, predictions_hot) |
|
|
35 |
single_dice_scores = tf.split(dice_scores, channels_out, axis=-1) |
|
|
36 |
|
|
|
37 |
# generate masks for proper mean scores |
|
|
38 |
mean_masks = tf.clip_by_value(tf.count_nonzero(labels_hot, axis=[1, 2, 3]), 0, 1) |
|
|
39 |
masks = tf.split(mean_masks, channels_out, axis=-1) |
|
|
40 |
mean_dice_scores = tf.reduce_mean(tf.boolean_mask(dice_scores, mean_masks), axis=-1) |
|
|
41 |
|
|
|
42 |
# average metrics |
|
|
43 |
metrics = { |
|
|
44 |
'accuracy': tf.metrics.accuracy(labels_hot, predictions_hot), |
|
|
45 |
'mean_pc_acc': tf.metrics.mean_per_class_accuracy(labels_var, predictions_var, 3), |
|
|
46 |
'mean_iou': tf.metrics.mean_iou(labels_hot, predictions_hot, 3), |
|
|
47 |
'mean_dice': tf.metrics.mean(mean_dice_scores)} |
|
|
48 |
|
|
|
49 |
# single class metrics |
|
|
50 |
for idx_ch in range(channels_out): |
|
|
51 |
metrics[f'dice_c{idx_ch}'] = tf.metrics.mean(single_dice_scores[idx_ch], weights=masks[idx_ch]) |
|
|
52 |
|
|
|
53 |
return metrics |
|
|
54 |
|
|
|
55 |
|
|
|
56 |
def calc_dice_scores(labels_hot, predictions_hot, eps=1e-12): |
|
|
57 |
|
|
|
58 |
nom = 2 * tf.reduce_sum(tf.cast(tf.logical_and(tf.cast(labels_hot, dtype=tf.bool), tf.cast(predictions_hot, dtype=tf.bool)), dtype=tf.float32), axis=(1, 2, 3)) |
|
|
59 |
denom = tf.reduce_sum(labels_hot + predictions_hot, axis=(1, 2, 3)) |
|
|
60 |
scores = tf.divide(nom, denom + eps) |
|
|
61 |
|
|
|
62 |
return scores |