Diff of /models/mrcnn.py [000000] .. [bb7f56]

--- a
+++ b/models/mrcnn.py
@@ -0,0 +1,1181 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
+published under MIT license.
+"""
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+
+sys.path.append("..")
+import utils.model_utils as mutils
+import utils.exp_utils as utils
+from custom_extensions.nms import nms
+from custom_extensions.roi_align import roi_align
+
+############################################################
+# Networks on top of backbone
+############################################################
+
+class RPN(nn.Module):
+    """
+    Region Proposal Network.
+    """
+
+    def __init__(self, cf, conv):
+
+        super(RPN, self).__init__()
+        self.dim = conv.dim
+
+        self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu)
+        self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
+        self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
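+        # channel arithmetic of the heads, e.g. for dim == 2 with 3 anchor ratios:
+        # conv_class -> 2 * 3 = 6 channels (one fg/bg score pair per anchor per location),
+        # conv_bbox  -> 2 * 2 * 3 = 12 channels (one 2*dim delta vector per anchor per location).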
+
+
+    def forward(self, x):
+        """
+        :param x: input feature maps (b, in_channels, y, x, (z))
+        :return: rpn_class_logits (b, n_anchors, 2)
+        :return: rpn_probs (b, n_anchors, 2)
+        :return: rpn_bbox (b, n_anchors, 2 * dim)
+        """
+
+        # Shared convolutional base of the RPN.
+        x = self.conv_shared(x)
+
+        # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
+        rpn_class_logits = self.conv_class(x)
+        # Reshape to (batch, anchors, 2)
+        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
+        rpn_class_logits = rpn_class_logits.permute(*axes)
+        rpn_class_logits = rpn_class_logits.contiguous()
+        rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)
+
+        # Softmax on last dimension (fg vs. bg).
+        rpn_probs = F.softmax(rpn_class_logits, dim=2)
+
+        # Bounding box refinement. (batch, anchors_per_location * 2 * dim, y, x, (z));
+        # per anchor: (dy, dx, (dz), log(dh), log(dw), (log(dd))).
+        rpn_bbox = self.conv_bbox(x)
+
+        # Reshape to (batch, anchors, 2 * dim)
+        rpn_bbox = rpn_bbox.permute(*axes)
+        rpn_bbox = rpn_bbox.contiguous()
+        rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)
+
+        return [rpn_class_logits, rpn_probs, rpn_bbox]
+
+
+
+class Classifier(nn.Module):
+    """
+    Head network for classification and bounding box refinement. Performs RoiAlign, processes the resulting features
+    through a shared convolutional base, and finally branches off into the classifier and regression heads.
+    """
+    def __init__(self, cf, conv):
+        super(Classifier, self).__init__()
+
+        self.dim = conv.dim
+        self.in_channels = cf.end_filts
+        self.pool_size = cf.pool_size
+        self.pyramid_levels = cf.pyramid_levels
+        # instance_norm does not work with spatial dims (1, 1, (1))
+        norm = cf.norm if cf.norm != 'instance_norm' else None
+
+        self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu)
+        self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu)
+        self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
+        self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim)
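+        # note: conv1 applies a kernel of size pool_size to the RoiAlign output of exactly
+        # that spatial size, so (with the conv wrapper's default zero padding) its output
+        # has spatial extent 1 and acts as a fully-connected layer per pooled roi; conv2
+        # (1x1) and the linear heads then operate on these collapsed features.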
+
+    def forward(self, x, rois):
+        """
+        :param x: input feature maps (b, in_channels, y, x, (z))
+        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
+        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)). Proposals of all batch elements
+        have been merged into one vector, while the batch-origin info is stored for re-allocation.
+        :return: mrcnn_class_logits (n_proposals, n_head_classes)
+        :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
+        """
+        x = pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = x.view(-1, self.in_channels * 4)
+        mrcnn_class_logits = self.linear_class(x)
+        mrcnn_bbox = self.linear_bbox(x)
+        mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2)
+
+        return [mrcnn_class_logits, mrcnn_bbox]
+
+
+
+class Mask(nn.Module):
+    """
+    Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the
+    output logits to allow for overlapping classes.
+    """
+    def __init__(self, cf, conv):
+        super(Mask, self).__init__()
+        self.pool_size = cf.mask_pool_size
+        self.pyramid_levels = cf.pyramid_levels
+        self.dim = conv.dim
+        self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
+        self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
+        self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
+        self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
+        if conv.dim == 2:
+            self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
+        else:
+            self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
+
+        self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True)
+        self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, rois):
+        """
+        :param x: input feature maps (b, in_channels, y, x, (z))
+        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
+        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)). Proposals of all batch elements
+        have been merged into one vector, while the batch-origin info is stored for re-allocation.
+        :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z))
+        """
+        x = pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = self.relu(self.deconv(x))
+        x = self.conv5(x)
+        x = self.sigmoid(x)
+        return x
+
+
+############################################################
+#  Loss Functions
+############################################################
+
+def compute_rpn_class_loss(rpn_match, rpn_class_logits, shem_poolsize):
+    """
+    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
+    :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
+    :param shem_poolsize: int. pool-size factor for stochastic hard example mining (SHEM): negatives are drawn
+    randomly from a pool of the (n_negative_samples * shem_poolsize) highest-scoring candidates.
+    :return: loss: torch tensor
+    :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
+    """
+
+    # filter out neutral anchors.
+    pos_indices = torch.nonzero(rpn_match == 1)
+    neg_indices = torch.nonzero(rpn_match == -1)
+
+    # loss for positive samples
+    if 0 not in pos_indices.size():
+        pos_indices = pos_indices.squeeze(1)
+        roi_logits_pos = rpn_class_logits[pos_indices]
+        pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda())
+    else:
+        pos_loss = torch.FloatTensor([0]).cuda()
+
+    # loss for negative samples: draw hard negative examples (SHEM)
+    # that match the number of positive samples, but at least 1.
+    if 0 not in neg_indices.size():
+        neg_indices = neg_indices.squeeze(1)
+        roi_logits_neg = rpn_class_logits[neg_indices]
+        negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
+        roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
+        neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
+        neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
+        np_neg_ix = neg_ix.cpu().data.numpy()
+    else:
+        neg_loss = torch.FloatTensor([0]).cuda()
+        np_neg_ix = np.array([]).astype('int32')
+
+    loss = (pos_loss + neg_loss) / 2
+    return loss, np_neg_ix
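+
+# for orientation: the SHEM draw used in compute_rpn_class_loss can be sketched as
+# (illustrative only -- the actual implementation is mutils.shem; probs = softmaxed
+# logits of the negative candidates):
+#   hard_pool = probs[:, 1].sort(descending=True)[1][:n_samples * shem_poolsize]
+#   neg_ix = hard_pool[torch.randperm(hard_pool.numel())[:n_samples]]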
+
+
+def compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas, rpn_match):
+    """
+    :param rpn_target_deltas:   (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
+    Uses 0 padding to fill in unused bbox deltas.
+    :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
+    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
+    :return: loss: torch 1D tensor.
+    """
+    if 0 not in torch.nonzero(rpn_match == 1).size():
+
+        indices = torch.nonzero(rpn_match == 1).squeeze(1)
+        # Pick bbox deltas that contribute to the loss
+        rpn_pred_deltas = rpn_pred_deltas[indices]
+        # Trim target bounding box deltas to the same length as rpn_pred_deltas.
+        target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
+        # Smooth L1 loss
+        loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
+    else:
+        loss = torch.FloatTensor([0]).cuda()
+
+    return loss
+
+
+def compute_mrcnn_class_loss(target_class_ids, pred_class_logits):
+    """
+    :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension.
+    :param pred_class_logits: (n_sampled_rois, n_classes)
+    :return: loss: torch 1D tensor.
+    """
+    if 0 not in target_class_ids.size():
+        loss = F.cross_entropy(pred_class_logits, target_class_ids.long())
+    else:
+        loss = torch.FloatTensor([0.]).cuda()
+
+    return loss
+
+
+def compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids):
+    """
+    :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
+    :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
+    :param target_class_ids: (n_sampled_rois)
+    :return: loss: torch 1D tensor.
+    """
+    if 0 not in torch.nonzero(target_class_ids > 0).size():
+        positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
+        positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
+        target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
+        pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
+        loss = F.smooth_l1_loss(pred_bbox, target_bbox)
+    else:
+        loss = torch.FloatTensor([0]).cuda()
+
+    return loss
+
+
+def compute_mrcnn_mask_loss(target_masks, pred_masks, target_class_ids):
+    """
+    :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array.
+    :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
+    :param target_class_ids: (n_sampled_rois)
+    :return: loss: torch 1D tensor.
+    """
+    if 0 not in torch.nonzero(target_class_ids > 0).size():
+        # Only positive ROIs contribute to the loss. And only
+        # the class specific mask of each ROI.
+        positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
+        positive_class_ids = target_class_ids[positive_ix].long()
+        y_true = target_masks[positive_ix, :, :].detach()
+        y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
+        loss = F.binary_cross_entropy(y_pred, y_true)
+    else:
+        loss = torch.FloatTensor([0]).cuda()
+
+    return loss
+
+
+############################################################
+#  Helper Layers
+############################################################
+
+def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf):
+    """
+    Receives anchor scores and selects a subset to pass as proposals
+    to the second stage. Filtering is done based on anchor scores and
+    non-max suppression to remove overlaps. It also applies bounding
+    box refinement deltas to the anchors.
+    :param rpn_pred_probs: (b, n_anchors, 2)
+    :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
+    :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2)))
+    :return: batch_out_proposals: Box coords + RPN foreground scores
+    for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
+    """
+    std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda()
+    norm = torch.from_numpy(cf.scale).float().cuda()
+    anchors = batch_anchors.clone()
+
+    batch_scores = rpn_pred_probs[:, :, 1]
+    # norm deltas
+    batch_deltas = rpn_pred_deltas * std_dev
+    batch_normalized_props = []
+    batch_out_proposals = []
+
+    # loop over batch dimension.
+    for ix in range(batch_scores.shape[0]):
+
+        scores = batch_scores[ix]
+        deltas = batch_deltas[ix]
+
+        # improve performance by trimming to top anchors by score
+        # and doing the rest on the smaller subset.
+        pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0])
+        scores, order = scores.sort(descending=True)
+        order = order[:pre_nms_limit]
+        scores = scores[:pre_nms_limit]
+        deltas = deltas[order, :]
+
+        # apply deltas to anchors to get refined anchors and filter with non-maximum suppression.
+        if batch_deltas.shape[-1] == 4:
+            boxes = mutils.apply_box_deltas_2D(anchors[order, :], deltas)
+            boxes = mutils.clip_boxes_2D(boxes, cf.window)
+        else:
+            boxes = mutils.apply_box_deltas_3D(anchors[order, :], deltas)
+            boxes = mutils.clip_boxes_3D(boxes, cf.window)
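+        # (for reference, the standard delta parametrization assumed here: cy += dy * h;
+        #  cx += dx * w; h *= exp(dh); w *= exp(dw) -- see mutils.apply_box_deltas_2D/3D
+        #  for the actual implementation.)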
+        # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant.
+        keep = nms.nms(boxes, scores, cf.rpn_nms_threshold)
+
+
+        keep = keep[:proposal_count]
+        boxes = boxes[keep, :]
+        rpn_scores = scores[keep][:, None]
+
+        # pad missing boxes with 0.
+        if boxes.shape[0] < proposal_count:
+            n_pad_boxes = proposal_count - boxes.shape[0]
+            zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda()
+            boxes = torch.cat([boxes, zeros], dim=0)
+            zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda()
+            rpn_scores = torch.cat([rpn_scores, zeros], dim=0)
+
+        # concat box and score info for monitoring/plotting.
+        batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy())
+        # normalize dimensions to range of 0 to 1.
+        normalized_boxes = boxes / norm
+        assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"
+
+        # add again batch dimension
+        batch_normalized_props.append(normalized_boxes.unsqueeze(0))
+
+    batch_normalized_props = torch.cat(batch_normalized_props)
+    batch_out_proposals = np.array(batch_out_proposals)
+
+    return batch_normalized_props, batch_out_proposals
+
+
+def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim):
+    """
+    Implements ROI Pooling on multiple levels of the feature pyramid.
+    :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z))
+    :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation.
+    (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)
+    :param pool_size: list of poolsizes in dims: [x, y, (z)]
+    :param pyramid_levels: list. [0, 1, 2, ...]
+    :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z)).
+    The output height/width (and depth) are those specified in pool_size.
+    """
+    boxes = rois[:, :dim*2]
+    batch_ixs = rois[:, dim*2]
+
+    # Assign each ROI to a level in the pyramid based on the ROI area.
+    if dim == 2:
+        y1, x1, y2, x2 = boxes.chunk(4, dim=1)
+    else:
+        y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1)
+
+    h = y2 - y1
+    w = x2 - x1
+
+    # Equation 1 in https://arxiv.org/abs/1612.03144, adapted to normalized
+    # coordinates: sqrt(h*w) is divided by 1 instead of the image area.
+    roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1])
+    # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly.
+    if len(pyramid_levels) == 5:
+        roi_level[h*w > 0.65] = 5
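+    # worked example: an roi covering a quarter of each (normalized) image side has
+    # sqrt(h*w) = 0.25, hence roi_level = round(4 + log2(0.25)) = 2 (before clamping).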
+
+    # Loop through levels and apply ROI pooling to each.
+    pooled = []
+    box_to_level = []
+    fmap_shapes = [f.shape for f in feature_maps]
+    for level_ix, level in enumerate(pyramid_levels):
+        ix = roi_level == level
+        if not ix.any():
+            continue
+        ix = torch.nonzero(ix)[:, 0]
+        level_boxes = boxes[ix, :]
+        # re-assign rois to feature map of original batch element.
+        ind = batch_ixs[ix].int()
+
+        # Keep track of which box is mapped to which level
+        box_to_level.append(ix)
+
+        # Stop gradient propagation to ROI proposals
+        level_boxes = level_boxes.detach()
+        if len(pool_size) == 2:
+            # remap to feature map coordinate system
+            y_exp, x_exp = fmap_shapes[level_ix][2:]  # exp = expansion
+            level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
+            pooled_features = roi_align.roi_align_2d(feature_maps[level_ix],
+                                                     torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
+                                                     pool_size)
+        else:
+            y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:]
+            level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
+            pooled_features = roi_align.roi_align_3d(feature_maps[level_ix],
+                                                     torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
+                                                     pool_size)
+        pooled.append(pooled_features)
+
+
+    # Pack pooled features into one tensor
+    pooled = torch.cat(pooled, dim=0)
+
+    # Pack box_to_level mapping into one array and add another
+    # column representing the order of pooled boxes
+    box_to_level = torch.cat(box_to_level, dim=0)
+
+    # Rearrange pooled features to match the order of the original boxes
+    _, box_to_level = torch.sort(box_to_level)
+    pooled = pooled[box_to_level, :, :]
+
+    return pooled
+
+
+def detection_target_layer(batch_proposals, batch_mrcnn_class_scores, batch_gt_class_ids, batch_gt_boxes, batch_gt_masks, cf):
+    """
+    Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, which seems to
+    benefit training, as opposed to sampling over the entire batch. Negatives are sampled via stochastic hard example
+    mining (SHEM), where negative proposals are drawn from a larger pool of the highest-scoring proposals to add
+    stochasticity. Scoring is obtained here as the max over all foreground probabilities as returned by the mrcnn
+    classifier (this worked better than loss-based class-balancing methods like online hard example mining or focal loss).
+    :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs).
+    boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
+    :param batch_mrcnn_class_scores: (n_proposals, n_classes)
+    :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
+    :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
+    :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
+    :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
+    :return: target_class_ids: (n_sampled_rois) containing target class labels of sampled proposals.
+    :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
+    :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals.
+    """
+    # normalization of target coordinates
+    if cf.dim == 2:
+        h, w = cf.patch_size
+        scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
+    else:
+        h, w, z = cf.patch_size
+        scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()
+
+    positive_count = 0
+    negative_count = 0
+    sample_positive_indices = []
+    sample_negative_indices = []
+    sample_deltas = []
+    sample_masks = []
+    sample_class_ids = []
+
+    std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()
+
+    # loop over batch and get positive and negative sample rois.
+    for b in range(len(batch_gt_class_ids)):
+
+        gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
+        gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda()
+        if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
+            gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
+        else:
+            gt_boxes = torch.FloatTensor().cuda()
+
+        # get proposals and indices of current batch element.
+        proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
+        batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)
+
+        # Compute overlaps matrix [proposals, gt_boxes]
+        if 0 not in gt_boxes.size():
+            if gt_boxes.shape[1] == 4:
+                assert cf.dim == 2, "gt_boxes shape {} doesn't match cf.dim {}".format(gt_boxes.shape, cf.dim)
+                overlaps = mutils.bbox_overlaps_2D(proposals, gt_boxes)
+            else:
+                assert cf.dim == 3, "gt_boxes shape {} doesn't match cf.dim {}".format(gt_boxes.shape, cf.dim)
+                overlaps = mutils.bbox_overlaps_3D(proposals, gt_boxes)
+
+            # Determine positive and negative ROIs
+            roi_iou_max = torch.max(overlaps, dim=1)[0]
+            # 1. Positive ROIs are those with >= 0.5 (2D) / 0.3 (3D) IoU with a GT box
+            positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
+            # 2. Negative ROIs are those with < 0.1 (2D) / 0.01 (3D) IoU with every GT box.
+            negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
+        else:
+            positive_roi_bool = torch.FloatTensor().cuda()
+            negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda()
+
+        # Sample Positive ROIs
+        if 0 not in torch.nonzero(positive_roi_bool).size():
+            positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
+            positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
+            rand_idx = torch.randperm(positive_indices.size()[0])
+            rand_idx = rand_idx[:positive_samples].cuda()
+            positive_indices = positive_indices[rand_idx]
+            positive_samples = positive_indices.size()[0]
+            positive_rois = proposals[positive_indices, :]
+            # Assign positive ROIs to GT boxes.
+            positive_overlaps = overlaps[positive_indices, :]
+            roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
+            roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
+            roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]
+
+            # Compute bbox refinement targets for positive ROIs
+            deltas = mutils.box_refinement(positive_rois, roi_gt_boxes)
+            deltas /= std_dev
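+            # (for reference, box_refinement computes the inverse of apply_box_deltas:
+            #  normalized center offsets and log size ratios between each positive roi
+            #  and its assigned gt box, here additionally normalized by cf.bbox_std_dev.)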
+
+            # Assign positive ROIs to GT masks
+            roi_masks = gt_masks[roi_gt_box_assignment]
+            assert roi_masks.shape[1] == 1, "gt masks are expected to have exactly one channel"
+
+            # Compute mask targets
+            boxes = positive_rois
+            box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float()
+            if len(cf.mask_shape) == 2:
+                # need to remap normalized box coordinates to unnormalized mask coordinates.
+                y_exp, x_exp = roi_masks.shape[2:]  # exp = expansion
+                boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
+                masks = roi_align.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
+            else:
+                y_exp, x_exp, z_exp = roi_masks.shape[2:]  # exp = expansion
+                boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
+                masks = roi_align.roi_align_3d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
+            masks = masks.squeeze(1)
+            # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
+            # binary cross entropy loss.
+            masks = torch.round(masks)
+
+            sample_positive_indices.append(batch_element_indices[positive_indices])
+            sample_deltas.append(deltas)
+            sample_masks.append(masks)
+            sample_class_ids.append(roi_gt_class_ids)
+            positive_count += positive_samples
+        else:
+            positive_samples = 0
+
+        # Negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
+        if 0 not in torch.nonzero(negative_roi_bool).size():
+            negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
+            r = 1.0 / cf.roi_positive_ratio
+            b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
+            roi_probs_neg = batch_mrcnn_class_scores[batch_element_indices[negative_indices]]
+            raw_sampled_indices = mutils.shem(roi_probs_neg, b_neg_count, cf.shem_poolsize)
+            sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
+            negative_count += raw_sampled_indices.size()[0]
+
+    if len(sample_positive_indices) > 0:
+        target_deltas = torch.cat(sample_deltas)
+        target_masks = torch.cat(sample_masks)
+        target_class_ids = torch.cat(sample_class_ids)
+
+    # Pad target information with zeros for negative ROIs.
+    if positive_count > 0 and negative_count > 0:
+        sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
+        zeros = torch.zeros(negative_count).int().cuda()
+        target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
+        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
+        target_deltas = torch.cat([target_deltas, zeros], dim=0)
+        zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
+        target_masks = torch.cat([target_masks, zeros], dim=0)
+    elif positive_count > 0:
+        sample_indices = torch.cat(sample_positive_indices)
+    elif negative_count > 0:
+        sample_indices = torch.cat(sample_negative_indices)
+        zeros = torch.zeros(negative_count).int().cuda()
+        target_class_ids = zeros
+        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
+        target_deltas = zeros
+        zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
+        target_masks = zeros
+    else:
+        sample_indices = torch.LongTensor().cuda()
+        target_class_ids = torch.IntTensor().cuda()
+        target_deltas = torch.FloatTensor().cuda()
+        target_masks = torch.FloatTensor().cuda()
+
+    return sample_indices, target_class_ids, target_deltas, target_masks
+
+
+############################################################
+#  Output Handler
+############################################################
+
+
+def refine_detections(cf, batch_ixs, rois, deltas, scores):
+    """
+    Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections.
+    :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS
+    :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor.
+    :param batch_ixs: (n_proposals) batch element assignment info for re-allocation.
+    :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier.
+    :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score))
+    """
+    # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point.
+    class_ids = []
+    fg_classes = cf.head_classes - 1
+    # repeat vectors to fill in predictions for all foreground classes.
+    for ii in range(1, fg_classes + 1):
+        class_ids += [ii] * rois.shape[0]
+    class_ids = torch.from_numpy(np.array(class_ids)).cuda()
+
+    batch_ixs = batch_ixs.repeat(fg_classes)
+    rois = rois.repeat(fg_classes, 1)
+    deltas = deltas.repeat(fg_classes, 1, 1)
+    scores = scores.repeat(fg_classes, 1)
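+    # layout after the repeats, e.g. with n rois and fg_classes = 2: rows [0, n) hold
+    # every roi evaluated under class 1, rows [n, 2n) under class 2; class_ids above
+    # was built to match exactly this row order.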
+
+    # get class-specific scores and bounding box deltas
+    idx = torch.arange(class_ids.size()[0]).long().cuda()
+    # advanced indexing with idx (instead of a [:, ...] slice) picks one row per
+    # (roi, class) pair, aligning each row with its assigned foreground class.
+    batch_ixs = batch_ixs[idx]
+    deltas_specific = deltas[idx, class_ids]
+    class_scores = scores[idx, class_ids]
+
+    # apply bounding box deltas. re-scale to image coordinates.
+    std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
+    scale = torch.from_numpy(cf.scale).float().cuda()
+    refined_rois = mutils.apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \
+        mutils.apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale
+
+    # round and cast to int since we're dealing with pixels now
+    refined_rois = mutils.clip_to_window(cf.window, refined_rois)
+    refined_rois = torch.round(refined_rois)
+
+    # filter out low confidence boxes
+    keep = idx
+    keep_bool = (class_scores >= cf.model_min_confidence)
+    if 0 not in torch.nonzero(keep_bool).size():
+
+        score_keep = torch.nonzero(keep_bool)[:, 0]
+        pre_nms_class_ids = class_ids[score_keep]
+        pre_nms_rois = refined_rois[score_keep]
+        pre_nms_scores = class_scores[score_keep]
+        pre_nms_batch_ixs = batch_ixs[score_keep]
+
+        for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)):
+
+            bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
+            bix_class_ids = pre_nms_class_ids[bixs]
+            bix_rois = pre_nms_rois[bixs]
+            bix_scores = pre_nms_scores[bixs]
+
+            for i, class_id in enumerate(mutils.unique1d(bix_class_ids)):
+
+                ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
+                # nms expects boxes sorted by score.
+                ix_rois = bix_rois[ixs]
+                ix_scores = bix_scores[ixs]
+                ix_scores, order = ix_scores.sort(descending=True)
+                ix_rois = ix_rois[order, :]
+
+                class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold)
+
+                # map indices back.
+                class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]]
+                # merge indices over classes for current batch element
+                b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep)))
+
+            # only keep top-k boxes of current batch-element
+            top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
+            b_keep = b_keep[top_ids]
+
+            # merge indices over batch elements.
+            batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep)))
+
+        keep = batch_keep
+
+    else:
+        keep = torch.tensor([0]).long().cuda()
+
+    # arrange output
+    output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)]
+    output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)]
+
+    result = torch.cat(output, dim=1)
+    # shape: (n_keep, 2 * dim + 3); [0:2*dim] are box coords, [2*dim] is batch_ix,
+    # [2*dim+1] is class_id, [2*dim+2] is score.
+    return result
+
+
+def get_results(cf, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
+    """
+    Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
+    :param img_shape: shape of the input image batch (b, c, y, x, (z)).
+    :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
+    :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
+    :param box_results_list: None or list of output boxes for monitoring/plotting.
+    each element is a list of boxes per batch element.
+    :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
+    :return: results_dict: dictionary with keys:
+             'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
+                      [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
+             'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values in [0, 1], only fg vs. bg for now.
+             class-specific return of masks will come with the implementation of instance-segmentation evaluation.
+    """
+    detections = detections.cpu().data.numpy()
+    if cf.dim == 2:
+        detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
+    else:
+        detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()
+
+    # restore batch dimension of merged detections using the batch_ix info.
+    batch_ixs = detections[:, cf.dim*2]
+    detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
+    mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]
+
+    # for test_forward, where no previous list exists.
+    if box_results_list is None:
+        box_results_list = [[] for _ in range(img_shape[0])]
+
+    seg_preds = []
+    # loop over batch and unmold detections.
+    for ix in range(img_shape[0]):
+
+        if 0 not in detections[ix].shape:
+            boxes = detections[ix][:, :2 * cf.dim].astype(np.int32)
+            class_ids = detections[ix][:, 2 * cf.dim + 1].astype(np.int32)
+            scores = detections[ix][:, 2 * cf.dim + 2]
+            masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]
+
+            # Filter out detections with zero area. Often only happens in early
+            # stages of training when the network weights are still a bit random.
+            if cf.dim == 2:
+                exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
+            else:
+                exclude_ix = np.where(
+                    (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
+
+            if exclude_ix.shape[0] > 0:
+                boxes = np.delete(boxes, exclude_ix, axis=0)
+                class_ids = np.delete(class_ids, exclude_ix, axis=0)
+                scores = np.delete(scores, exclude_ix, axis=0)
+                masks = np.delete(masks, exclude_ix, axis=0)
+
+            # Resize masks to original image size and set boundary threshold.
+            full_masks = []
+            permuted_image_shape = list(img_shape[2:]) + [img_shape[1]]
+            if return_masks:
+                for i in range(masks.shape[0]):
+                    # Convert neural network mask to full size mask.
+                    full_masks.append(mutils.unmold_mask_2D(masks[i], boxes[i], permuted_image_shape)
+                    if cf.dim == 2 else mutils.unmold_mask_3D(masks[i], boxes[i], permuted_image_shape))
+            # if masks are returned, take max over binary full masks of all predictions in this image.
+            # right now only binary masks for plotting/monitoring. for instance segmentation return all proposal masks.
+            final_masks = np.max(np.array(full_masks), 0) if len(full_masks) > 0 else np.zeros(
+                (*permuted_image_shape[:-1],))
+
+            # add final predictions to results.
+            if 0 not in boxes.shape:
+                for ix2, score in enumerate(scores):
+                    box_results_list[ix].append({'box_coords': boxes[ix2], 'box_score': score,
+                                                 'box_type': 'det', 'box_pred_class_id': class_ids[ix2]})
+        else:
+            # pad with zero dummy masks.
+            final_masks = np.zeros(img_shape[2:])
+
+        seg_preds.append(final_masks)
+
+    # create and fill results dictionary.
+    results_dict = {'boxes': box_results_list,
+                    'seg_preds': np.round(np.array(seg_preds))[:, np.newaxis].astype('uint8')}
+
+    return results_dict
+
+
+############################################################
+#  Mask R-CNN Class
+############################################################
+
+class net(nn.Module):
+
+
+    def __init__(self, cf, logger):
+
+        super(net, self).__init__()
+        self.cf = cf
+        self.logger = logger
+        self.build()
+
+        if self.cf.weight_init is not None:
+            logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
+            mutils.initialize_weights(self)
+        else:
+            logger.info("using default pytorch weight init")
+
+
+    def build(self):
+        """Build Mask R-CNN architecture."""
+
+        # Image size must be divisible by 2 multiple times.
+        h, w = self.cf.patch_size[:2]
+        if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
+            raise Exception("Image size must be divisible by 2 at least 5 times "
+                            "to avoid fractions when downscaling and upscaling. "
+                            "For example, use 256, 320, 384, 448, 512, etc.")
+        if len(self.cf.patch_size) == 3:
+            d = self.cf.patch_size[2]
+            if d / 2**3 != int(d / 2**3):
+                raise Exception("Image z dimension must be dividable by 2 at least 3 times "
+                                "to avoid fractions when downscaling and upscaling.")
+
+
+
+        # instantiate abstract multi-dimensional conv class and backbone class.
+        conv = mutils.NDConvGenerator(self.cf.dim)
+        backbone = utils.import_module('bbone', self.cf.backbone_path)
+
+        # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
+        self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
+        self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
+        self.fpn = backbone.FPN(self.cf, conv)
+        self.rpn = RPN(self.cf, conv)
+        self.classifier = Classifier(self.cf, conv)
+        self.mask = Mask(self.cf, conv)
+
+
+    def train_forward(self, batch, is_validation=False):
+        """
+        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
+        for processing, computes losses, and stores outputs in a dictionary.
+        :param batch: dictionary containing 'data', 'seg', etc.
+                    data_dict['roi_masks']: (b, n(b), 1, h(n), w(n), (z(n))) list like batch['class_target'] but with
+                    arrays (masks) in place of integers. n == number of rois per batch element.
+        :return: results_dict: dictionary with keys:
+                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
+                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
+                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
+                'monitor_values': dict of values to be monitored.
+        """
+        img = batch['data']
+        if "roi_labels" in batch.keys():
+            raise Exception("Key for roi-wise class targets changed in v0.1.0 from 'roi_labels' to 'class_target'.\n"
+                            "If you use DKFZ's batchgenerators, please make sure you run version >= 0.20.1.")
+        gt_class_ids = batch['class_target']
+        gt_boxes = batch['bb_target']
+        gt_masks = batch['roi_masks']
+        img = torch.from_numpy(img).float().cuda()
+        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
+        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()
+
+        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
+        box_results_list = [[] for _ in range(img.shape[0])]
+
+        # forward passes. 1. general forward pass, where no activations are saved in the second stage (for performance
+        # monitoring and loss sampling). 2. second-stage forward pass of sampled rois with stored activations for backprop.
+        rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
+        mrcnn_class_logits, mrcnn_pred_deltas, mrcnn_pred_mask, target_class_ids, mrcnn_target_deltas, target_mask,  \
+        sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes, gt_masks)
+
+        # loop over batch
+        for b in range(img.shape[0]):
+            if len(gt_boxes[b]) > 0:
+
+                # add gt boxes to output list for monitoring.
+                for ix in range(len(gt_boxes[b])):
+                    box_results_list[b].append({'box_coords': batch['bb_target'][b][ix],
+                                                'box_label': batch['class_target'][b][ix], 'box_type': 'gt'})
+
+                # match gt boxes with anchors to generate targets for RPN losses.
+                rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])
+
+                # add positive anchors used for loss to output list for monitoring.
+                pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
+                for p in pos_anchors:
+                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
+
+            else:
+                rpn_match = np.array([-1]*self.np_anchors.shape[0])
+                rpn_target_deltas = np.array([0])
+
+            rpn_match_gpu = torch.from_numpy(rpn_match).cuda()
+            rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()
+
+            # compute RPN losses.
+            rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match_gpu, rpn_class_logits[b], self.cf.shem_poolsize)
+            rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match_gpu)
+            batch_rpn_class_loss += rpn_class_loss / img.shape[0]
+            batch_rpn_bbox_loss += rpn_bbox_loss / img.shape[0]
+
+            # add negative anchors used for loss to output list for monitoring.
+            neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[rpn_match == -1][neg_anchor_ix], img.shape[2:])
+            for n in neg_anchors:
+                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
+
+            # add highest scoring proposals to output list for monitoring.
+            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
+            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
+                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})
+
+        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
+        if 0 not in sample_proposals.shape:
+            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
+            for ix, r in enumerate(rois):
+                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
+                                            'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})
+
+        # compute mrcnn losses.
+        mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
+        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)
+
+        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
+        # In this case, the mask_loss is taken out of training.
+        if not self.cf.frcnn_mode:
+            mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, mrcnn_pred_mask, target_class_ids)
+        else:
+            mrcnn_mask_loss = torch.FloatTensor([0]).cuda()
+
+        loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + mrcnn_mask_loss
+
+        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
+        dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]
+
+        # run unmolding of predictions for monitoring and merge all results to one dictionary.
+        return_masks = True  # self.cf.return_masks_in_val if is_validation else False
+        results_dict = get_results(self.cf, img.shape, detections, detection_masks,
+                                   box_results_list, return_masks=return_masks)
+
+        results_dict['torch_loss'] = loss
+        results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': mrcnn_class_loss.item()}
+
+        results_dict['logger_string'] =  \
+            "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " \
+            "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(),
+                                                     batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(),
+                                                     mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)
+
+        return results_dict
+
+
+    def test_forward(self, batch, return_masks=True):
+        """
+        test method. wrapper around forward pass of network without usage of any ground truth information.
+        prepares input data for processing and stores outputs in a dictionary.
+        :param batch: dictionary containing 'data'
+        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
+        :return: results_dict: dictionary with keys:
+               'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
+                       [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
+               'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
+        """
+        img = batch['data']
+        img = torch.from_numpy(img).float().cuda()
+        _, _, _, detections, detection_masks = self.forward(img)
+        results_dict = get_results(self.cf, img.shape, detections, detection_masks, return_masks=return_masks)
+        return results_dict
+
+
+    def forward(self, img, is_training=True):
+        """
+        :param img: input images (b, c, y, x, (z)).
+        :return: rpn_pred_logits: (b, n_anchors, 2)
+        :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
+        :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
+        :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
+        :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
+        """
+        # extract features.
+        fpn_outs = self.fpn(img)
+        rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
+        self.mrcnn_feature_maps = rpn_feature_maps
+
+        # loop through pyramid layers and apply RPN.
+        layer_outputs = []  # list of lists
+        for p in rpn_feature_maps:
+            layer_outputs.append(self.rpn(p))
+
+        # concatenate layer outputs.
+        # convert from list of lists of level outputs to list of lists of outputs across levels.
+        # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
+        outputs = list(zip(*layer_outputs))
+        outputs = [torch.cat(list(o), dim=1) for o in outputs]
+        rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs
+
+        # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
+        proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
+        batch_rpn_rois, batch_proposal_boxes = refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, self.anchors, self.cf)
+
+        # merge batch dimension of proposals while storing allocation info in coordinate dimension.
+        batch_ixs = torch.from_numpy(np.repeat(np.arange(batch_rpn_rois.shape[0]), batch_rpn_rois.shape[1])).float().cuda()
+        rpn_rois = batch_rpn_rois.view(-1, batch_rpn_rois.shape[2])
+        self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
+
+        # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
+        # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.)
+        # for inference/monitoring as well as sampling of rois for the loss functions.
+        # processed in chunks of roi_chunk_size to fit into GPU memory.
+        chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
+        class_logits_list, bboxes_list = [], []
+        with torch.no_grad():
+            for chunk in chunked_rpn_rois:
+                chunk_class_logits, chunk_bboxes = self.classifier(self.mrcnn_feature_maps, chunk)
+                class_logits_list.append(chunk_class_logits)
+                bboxes_list.append(chunk_bboxes)
+        batch_mrcnn_class_logits = torch.cat(class_logits_list, 0)
+        batch_mrcnn_bbox = torch.cat(bboxes_list, 0)
+        self.batch_mrcnn_class_scores = F.softmax(batch_mrcnn_class_logits, dim=1)
+
+        # refine classified proposals, filter and return final detections.
+        detections = refine_detections(self.cf, batch_ixs, rpn_rois, batch_mrcnn_bbox, self.batch_mrcnn_class_scores)
+
+        # forward remaining detections through mask-head to generate corresponding masks.
+        scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
+        scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()
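+        # note: re-scaling uses img.shape[2] for both y- and x-coordinates (assumes
+        # quadratic patches in-plane) and appends a 1 so the batch_ix column stays unchanged.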
+
+
+        detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
+        with torch.no_grad():
+            detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)
+
+        return [rpn_pred_logits, rpn_pred_deltas, batch_proposal_boxes, detections, detection_masks]
+
+
+    def loss_samples_forward(self, batch_gt_class_ids, batch_gt_boxes, batch_gt_masks):
+        """
+        this is the second forward pass through the second stage (features from stage one are re-used).
+        samples few rois in detection_target_layer and forwards only those for loss computation.
+        :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
+        :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
+        :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
+        :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
+        :return: sample_boxes: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
+        :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
+        :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
+        :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
+        :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
+        :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
+        """
+        # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
+        sample_ix, sample_target_class_ids, sample_target_deltas, sample_target_mask = \
+            detection_target_layer(self.rpn_rois_batch_info, self.batch_mrcnn_class_scores,
+                                   batch_gt_class_ids, batch_gt_boxes, batch_gt_masks, self.cf)
+
+        # re-use feature maps and RPN output from first forward pass.
+        sample_proposals = self.rpn_rois_batch_info[sample_ix]
+        if 0 not in sample_proposals.size():
+            sample_logits, sample_boxes = self.classifier(self.mrcnn_feature_maps, sample_proposals)
+            sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
+        else:
+            sample_logits = torch.FloatTensor().cuda()
+            sample_boxes = torch.FloatTensor().cuda()
+            sample_mask = torch.FloatTensor().cuda()
+
+        return [sample_logits, sample_boxes, sample_mask, sample_target_class_ids, sample_target_deltas,
+                sample_target_mask, sample_proposals]
\ No newline at end of file