
--- a
+++ b/RefineNet & SESNet/utils/training.py
@@ -0,0 +1,170 @@
+import tensorflow as tf
+
+
+def get_labels_from_annotation(annotation_tensor, class_labels):
+    """Returns tensor of size (width, height, num_classes) derived from annotation tensor.
+    The function returns tensor that is of a size (width, height, num_classes) which
+    is derived from annotation tensor with sizes (width, height) where value at
+    each position represents a class. The functions requires a list with class
+    values like [0, 1, 2 ,3] -- they are used to derive labels. Derived values will
+    be ordered in the same way as the class numbers were provided in the list. Last
+    value in the aforementioned list represents a value that indicate that the pixel
+    should be masked out. So, the size of num_classes := len(class_labels) - 1.
+    
+    Parameters
+    ----------
+    annotation_tensor : Tensor of size (width, height)
+        Tensor with class labels for each element
+    class_labels : list of ints
+        List that contains the numbers that represent classes. Last
+        value in the list should represent the number that was used
+        for masking out.
+        
+    Returns
+    -------
+    labels_2d_stacked : Tensor of size (width, height, num_classes).
+        Tensor with labels for each pixel.
+    """
+    
+    # The last value in the classes list indicates which number was used
+    # in the annotation to mask out the ambiguous regions or regions that
+    # should not be used for training, so it is excluded here.
+    # TODO: probably replace the class_labels list with a custom object
+    valid_entries_class_labels = class_labels[:-1]
+    
+    # Stack the binary masks for each class
+    labels_2d = list(map(lambda x: tf.equal(annotation_tensor, x),
+                    valid_entries_class_labels))
+
+    # Perform the merging of all of the binary masks into one matrix
+    labels_2d_stacked = tf.stack(labels_2d, axis=2)
+    
+    # Convert tf.bool to tf.float.
+    # Later on the labels and logits will be used in
+    # tf.nn.softmax_cross_entropy_with_logits(), where
+    # they have to be of the float type.
+    labels_2d_stacked_float = tf.to_float(labels_2d_stacked)
+    
+    return labels_2d_stacked_float
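+
+# Illustrative sketch for get_labels_from_annotation (hedged; values assumed,
+# TF 1.x graph mode): with class_labels = [0, 1, 255] the mask-out label 255
+# is dropped, so every pixel maps to a one-hot row over the two real classes.
+#
+#   annotation = tf.constant([[0, 1], [1, 255]])
+#   labels = get_labels_from_annotation(annotation, class_labels=[0, 1, 255])
+#   # labels has shape (2, 2, 2); the masked pixel's row is all zeros.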
+
+
+def get_labels_from_annotation_batch(annotation_batch_tensor, class_labels):
+    """Returns tensor of size (batch_size, width, height, num_classes) derived
+    from annotation batch tensor. The function returns tensor that is of a size
+    (batch_size, width, height, num_classes) which is derived from annotation tensor
+    with sizes (batch_size, width, height) where value at each position represents a class.
+    The functions requires a list with class values like [0, 1, 2 ,3] -- they are
+    used to derive labels. Derived values will be ordered in the same way as
+    the class numbers were provided in the list. Last value in the aforementioned
+    list represents a value that indicate that the pixel should be masked out.
+    So, the size of num_classes len(class_labels) - 1.
+    
+    Parameters
+    ----------
+    annotation_batch_tensor : Tensor of size (batch_size, width, height)
+        Tensor with class labels for each element
+    class_labels : list of ints
+        List that contains the numbers that represent classes. Last
+        value in the list should represent the number that was used
+        for masking out.
+        
+    Returns
+    -------
+    batch_labels : Tensor of size (batch_size, width, height, num_classes).
+        Tensor with labels for each pixel in the batch.
+    """
+    
+    batch_labels = tf.map_fn(fn=lambda x: get_labels_from_annotation(annotation_tensor=x, class_labels=class_labels),
+                             elems=annotation_batch_tensor,
+                             dtype=tf.float32)
+    
+    return batch_labels
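+
+# Hedged usage sketch for get_labels_from_annotation_batch (shapes assumed):
+# for a batch of shape (batch_size, w, h) the call below yields one-hot labels
+# of shape (batch_size, w, h, len(class_labels) - 1).
+#
+#   batch = tf.placeholder(tf.int32, shape=(None, 224, 224))
+#   batch_labels = get_labels_from_annotation_batch(batch, class_labels=[0, 1, 255])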
+
+
+def get_valid_entries_indices_from_annotation_batch(annotation_batch_tensor, class_labels):
+    """Returns tensor of size (num_valid_eintries, 3).
+    Returns tensor that contains the indices of valid entries according
+    to the annotation tensor. This can be used to later on extract only
+    valid entries from logits tensor and labels tensor. This function is
+    supposed to work with a batch input like [b, w, h] -- where b is a
+    batch size, w, h -- are width and height sizes. So the output is
+    a tensor which contains indexes of valid entries. This function can
+    also work with a single annotation like [w, h] -- the output will
+    be (num_valid_eintries, 2).
+    
+    Parameters
+    ----------
+    annotation_batch_tensor : Tensor of size (batch_size, width, height)
+        Tensor with class labels for each batch
+    class_labels : list of ints
+        List that contains the numbers that represent classes. Last
+        value in the list should represent the number that was used
+        for masking out.
+        
+    Returns
+    -------
+    valid_labels_indices : Tensor of size (num_valid_entries, 3).
+        Tensor with indices of valid entries
+    """
+    
+    # The last value in the classes list indicates which number was used
+    # in the annotation to mask out the ambiguous regions or regions that
+    # should not be used for training.
+    # TODO: probably replace the class_labels list with a custom object
+    mask_out_class_label = class_labels[-1]
+    
+    # Get a binary mask for the pixels that we want to
+    # use for training. We do this because some pixels
+    # are marked as ambiguous and we don't want to use
+    # them for training to avoid confusing the model.
+    valid_labels_mask = tf.not_equal(annotation_batch_tensor,
+                                     mask_out_class_label)
+    
+    valid_labels_indices = tf.where(valid_labels_mask)
+    
+    return tf.to_int32(valid_labels_indices)
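+
+# Illustrative sketch for get_valid_entries_indices_from_annotation_batch
+# (hedged; values assumed): tf.where returns one row of coordinates per
+# unmasked pixel, so a single (w, h) annotation yields shape (num_valid_entries, 2).
+#
+#   annotation = tf.constant([[0, 255], [1, 2]])
+#   indices = get_valid_entries_indices_from_annotation_batch(
+#       annotation, class_labels=[0, 1, 2, 255])
+#   # indices -> [[0, 0], [1, 0], [1, 1]]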
+
+
+def get_valid_logits_and_labels(annotation_batch_tensor,
+                                logits_batch_tensor,
+                                class_labels):
+    """Returns two tensors of size (num_valid_entries, num_classes).
+    The function converts annotation batch tensor input of the size
+    (batch_size, height, width) into label tensor (batch_size, height,
+    width, num_classes) and then selects only valid entries, resulting
+    in tensor of the size (num_valid_entries, num_classes). The function
+    also returns the tensor with corresponding valid entries in the logits
+    tensor. Overall, two tensors of the same size are returned and can later
+    be used as input to tf.nn.softmax_cross_entropy_with_logits() to
+    get the cross-entropy error for each entry.
+    
+    Parameters
+    ----------
+    annotation_batch_tensor : Tensor of size (batch_size, width, height)
+        Tensor with class labels for each batch
+    logits_batch_tensor : Tensor of size (batch_size, width, height, num_classes)
+        Tensor with logits. Usually obtained from the inference pass of an FCN network.
+    class_labels : list of ints
+        List that contains the numbers that represent classes. Last
+        value in the list should represent the number that was used
+        for masking out.
+        
+    Returns
+    -------
+    (valid_labels_batch_tensor, valid_logits_batch_tensor) : Two Tensors of size (num_valid_entries, num_classes).
+        Tensors that represent valid labels and logits.
+    """
+
+    labels_batch_tensor = get_labels_from_annotation_batch(annotation_batch_tensor=annotation_batch_tensor,
+                                                           class_labels=class_labels)
+    
+    valid_batch_indices = get_valid_entries_indices_from_annotation_batch(annotation_batch_tensor=annotation_batch_tensor,
+                                                                          class_labels=class_labels)
+
+    valid_labels_batch_tensor = tf.gather_nd(params=labels_batch_tensor, indices=valid_batch_indices)
+
+    valid_logits_batch_tensor = tf.gather_nd(params=logits_batch_tensor, indices=valid_batch_indices)
+
+    return valid_labels_batch_tensor, valid_logits_batch_tensor
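+
+
+if __name__ == "__main__":
+    # Hedged smoke test (assumes TF 1.x graph mode; all values are made up):
+    # two fake 2x2 annotations, random logits, classes {0, 1} with 255 as
+    # the mask-out label.
+    annotations = tf.constant([[[0, 1], [255, 1]],
+                               [[1, 0], [0, 255]]], dtype=tf.int32)
+    logits = tf.random_normal(shape=(2, 2, 2, 2))
+
+    valid_labels, valid_logits = get_valid_logits_and_labels(
+        annotation_batch_tensor=annotations,
+        logits_batch_tensor=logits,
+        class_labels=[0, 1, 255])
+
+    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=valid_labels,
+                                                            logits=valid_logits)
+    loss = tf.reduce_mean(cross_entropy)
+
+    with tf.Session() as sess:
+        # Both gathered tensors have shape (num_valid_entries, num_classes) = (6, 2).
+        print(sess.run([tf.shape(valid_labels), loss]))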