Diff of /utils.py [000000] .. [a49583]

Switch to side-by-side view

--- a
+++ b/utils.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn
+
+# set all modules to training mode
+def set_to_train_mode(model, report=True):
+    for _k in model._modules.keys():
+        if 'new' in _k:
+           if report:
+              print("Setting {0:} to training mode".format(_k))
+           model._modules[_k].train(True)
+
+# switch on gradients and add parameters to the list of trainable parameters
+# implemented only for update_type=new_bn (classification module S + all batch
+# normalization layers in the model
+# This doesn't apply to running_var, running_mean and batch tracking (frozen)
+# in batch normalization layers
+# This assumes that trainable layers have either 'new' or 'bn' in their name
+def switch_model_on(model, ckpt, list_trained_pars):
+    param_names = ckpt['model_weights'].keys()
+    for _n,_p in model.named_parameters():
+      if _p.dtype==torch.float32 and _n in param_names:
+         if not 'new' in _n and not 'bn' in _n:
+            _p.requires_grad_(True)
+            print(_n, "grads on")
+         else:
+            _p.requires_grad_(True)
+            list_trained_pars.append(_p)
+            print(_n, "trainable pars")
+      elif _p.dtype==torch.float32 and not _n in param_names:
+         _p.requires_grad_(True)
+         list_trained_pars.append(_p)
+         print(_n, "new pars, trainable")
+
+# AVERAGE PRECISION COMPUTATION
+# adapted from Matterport Mask R-CNN implementation
+# https://github.com/matterport/Mask_RCNN
+# inputs are predicted masks>threshold (0.5)
+def compute_overlaps_masks(masks1, masks2):
+    # masks1: (HxWxnum_pred)
+    # masks2: (HxWxnum_gts)
+    # flatten masks and compute their areas
+    # masks1: num_pred x H*W
+    # masks2: num_gt x H*W
+    # overlap: num_pred x num_gt
+    masks1 = masks1.flatten(start_dim=1)
+    masks2 = masks2.flatten(start_dim=1)
+    area2 = masks2.sum(dim=(1,), dtype=torch.float)
+    area1 = masks1.sum(dim=(1,), dtype=torch.float)
+    # duplicatae each predicted mask num_gt times, compute the union (sum) of areas
+    # num_pred x num_gt
+    area1 = area1.unsqueeze_(1).expand(*[area1.size()[0], area2.size()[0]])
+    union = area1 + area2
+    # intersections and union: transpose predictions, the overlap matrix is num_predxnum_gts
+    intersections = masks1.float().matmul(masks2.t().float())
+    # +1: divide by 0
+    overlaps = intersections / (union-intersections)
+    return overlaps
+
+
+# compute average precision for the  specified IoU threshold
+def compute_matches(gt_boxes, gt_class_ids, gt_masks,
+                    pred_boxes, pred_class_ids, pred_scores, pred_masks,
+                    iou_threshold=0.5):
+    # Sort predictions by score from high to low
+    indices = pred_scores.argsort().flip(dims=(0,))
+    pred_boxes = pred_boxes[indices]
+    pred_class_ids = pred_class_ids[indices]
+    pred_scores = pred_scores[indices]
+    pred_masks = pred_masks[indices,...]
+    # Compute IoU overlaps [pred_masks, gt_masks]
+    overlaps = compute_overlaps_masks(pred_masks, gt_masks)
+    # separate predictions for each gt object (a total of gt_masks splits
+    split_overlaps = overlaps.t().split(1)
+    # Loop through predictions and find matching ground truth boxes
+    match_count = 0
+    # At the start all predictions are False Positives, all gts are False Negatives
+    pred_match = torch.tensor([-1]).expand(pred_boxes.size()[0]).float()
+    gt_match = torch.tensor([-1]).expand(gt_boxes.size()[0]).float()
+    # Alex: loop through each column (gt object), get
+    for _i, splits in enumerate(split_overlaps):
+        # ground truth class
+        gt_class = gt_class_ids[_i]
+        if (splits>iou_threshold).any():
+           # get best predictions, their indices inthe IoU tensor and their classes
+           global_best_preds_inds = torch.nonzero(splits[0]>iou_threshold).view(-1)
+           pred_classes = pred_class_ids[global_best_preds_inds]
+           best_preds = splits[0][splits[0]>iou_threshold]
+           #  sort them locally-nothing else,
+           local_best_preds_sorted = best_preds.argsort().flip(dims=(0,))
+           # loop through each prediction's index, sorted in the descending order
+           for p in local_best_preds_sorted:
+               if pred_classes[p]==gt_class:
+                  # Hit?
+                  match_count +=1
+                  pred_match[global_best_preds_inds[p]] = _i
+                  gt_match[_i] = global_best_preds_inds[p]
+                  # important: if the prediction is True Positive, finish the loop
+                  break
+
+    return gt_match, pred_match, overlaps
+
+
+# AP for a single IoU threshold and 1 image
+def compute_ap(gt_boxes, gt_class_ids, gt_masks,
+               pred_boxes, pred_class_ids, pred_scores, pred_masks,
+               iou_threshold=0.5):
+
+    # Get matches and overlaps
+    gt_match, pred_match, overlaps = compute_matches(
+        gt_boxes, gt_class_ids, gt_masks,
+        pred_boxes, pred_class_ids, pred_scores, pred_masks,
+        iou_threshold)
+
+    # Compute precision and recall at each prediction box step
+    precisions = (pred_match>-1).cumsum(dim=0).float().div(torch.arange(pred_match.numel()).float()+1)
+    recalls = (pred_match>-1).cumsum(dim=0).float().div(gt_match.numel())
+    # Pad with start and end values to simplify the math
+    precisions = torch.cat([torch.tensor([0]).float(), precisions, torch.tensor([0]).float()])
+    recalls = torch.cat([torch.tensor([0]).float(), recalls, torch.tensor([1]).float()])
+    # Ensure precision values decrease but don't increase. This way, the
+    # precision value at each recall threshold is the maximum it can be
+    # for all following recall thresholds, as specified by the VOC paper.
+    for i in range(len(precisions) - 2, -1, -1):
+        precisions[i] = torch.max(precisions[i], precisions[i + 1])
+    # Compute mean AP over recall range
+    indices = torch.nonzero(recalls[:-1] !=recalls[1:]).squeeze_(1)+1
+    map = torch.sum((recalls[indices] - recalls[indices - 1]) *
+                 precisions[indices])
+    return map, precisions, recalls, overlaps
+
+
+# easier boolean argument
+def str_to_bool(v):
+    return v.lower() in ('true')
+