--- a
+++ b/medseg_dl/model/metrics.py
@@ -0,0 +1,62 @@
+import tensorflow as tf
+import os
+
+
+def metrics_fn(labels, probs, channels_out, b_verbose=False):
+    """ Build streaming segmentation metrics and the ops that drive them.
+
+    labels/probs: per-class maps with the class axis last (both are argmax-ed over axis -1);
+    channels_out: number of classes; b_verbose: print label/prediction shapes when labels_hot is evaluated.
+    Returns (metrics_init_op, update_metrics_op, metrics_values); metrics_values maps each metric
+    name to its value tensor, and metrics_init_op re-initializes the metric accumulators. """
+    labels_var = tf.argmax(labels, axis=-1)
+    labels_hot = tf.one_hot(labels_var, channels_out)
+    predictions_var = tf.argmax(probs, axis=-1)
+    predictions_hot = tf.one_hot(predictions_var, channels_out)
+    if b_verbose:  # tf.Print must wrap a tensor (it cannot convert the metrics dict); labels_hot passes through
+        labels_hot = tf.Print(labels_hot, [tf.shape(labels_hot), tf.shape(predictions_hot)], 'fetched metrics with labels/preds: ', summarize=20)
+    metrics = fetch_metrics(labels_hot, predictions_hot, labels_var, predictions_var, channels_out)
+    # Value tensors to read after updating, and one grouped op that updates every metric at once
+    metrics_values = {k: v[0] for k, v in metrics.items()}
+    update_metrics_op = tf.group(*[op for _, op in metrics.values()])
+    # Reset op for the metrics' local variables (e.g. when restarting an epoch). get_collection
+    # regex-matches the scope as a name prefix: the trailing '/' keeps scope 'metrics' from also
+    # resetting 'metrics_eval/...'; at the root scope (empty name) every local variable is taken.
+    name_scope = tf.get_default_graph().get_name_scope()
+    metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=f'{name_scope}/' if name_scope else None)
+    metrics_init_op = tf.variables_initializer(metric_variables)
+    return metrics_init_op, update_metrics_op, metrics_values
+
+
+def fetch_metrics(labels_hot, predictions_hot, labels_var, predictions_var, channels_out):
+    """ Create the streaming tf.metrics used for evaluation.
+
+    labels_var/predictions_var hold class indices, labels_hot/predictions_hot their one-hot maps
+    (class axis last; presumably (batch, x, y, z, classes) since axes 1-3 are reduced — TODO confirm).
+    Returns a dict: metric name -> (value_tensor, update_op) as produced by tf.metrics. """
+    dice_scores = calc_dice_scores(labels_hot, predictions_hot)  # (batch, channels_out) under that assumption
+    single_dice_scores = tf.split(dice_scores, channels_out, axis=-1)
+    # 1 where a class occurs in a sample's ground truth; absent classes are excluded from dice averages
+    mean_masks = tf.clip_by_value(tf.count_nonzero(labels_hot, axis=[1, 2, 3]), 0, 1)
+    masks = tf.split(mean_masks, channels_out, axis=-1)
+    # mean over all present (sample, class) pairs of this batch; tf.boolean_mask documents a bool mask
+    mean_dice_scores = tf.reduce_mean(tf.boolean_mask(dice_scores, tf.cast(mean_masks, tf.bool)), axis=-1)
+    # accuracy/mean_iou need class indices (one-hot maps would be scored per vector element, inflating
+    # accuracy and making mean_iou compare 0/1 entries); the class count follows channels_out.
+    metrics = {
+            'accuracy': tf.metrics.accuracy(labels_var, predictions_var),
+            'mean_pc_acc': tf.metrics.mean_per_class_accuracy(labels_var, predictions_var, channels_out),
+            'mean_iou': tf.metrics.mean_iou(labels_var, predictions_var, channels_out),
+            'mean_dice': tf.metrics.mean(mean_dice_scores)}
+    for idx_ch in range(channels_out):  # per-class dice, averaged only over samples containing the class
+        metrics[f'dice_c{idx_ch}'] = tf.metrics.mean(single_dice_scores[idx_ch], weights=masks[idx_ch])
+    return metrics
+
+
+def calc_dice_scores(labels_hot, predictions_hot, eps=1e-12):
+    """ Dice score 2|A & B| / (|A| + |B| + eps), summing axes 1-3 and keeping the remaining axes.
+    The intersection uses bool casts (non-zero counts as set); eps makes a class empty in both maps score 0. """
+    reduce_axes = (1, 2, 3)
+    overlap = tf.cast(tf.logical_and(tf.cast(labels_hot, dtype=tf.bool), tf.cast(predictions_hot, dtype=tf.bool)), dtype=tf.float32)
+    total = tf.reduce_sum(labels_hot + predictions_hot, axis=reduce_axes)
+    return tf.divide(2 * tf.reduce_sum(overlap, axis=reduce_axes), total + eps)