Diff of /utils.py [000000] .. [a49583]

Switch to unified view

a b/utils.py
1
import torch
2
import torch.nn
3
4
# set all modules to training mode
5
def set_to_train_mode(model, report=True):
6
    for _k in model._modules.keys():
7
        if 'new' in _k:
8
           if report:
9
              print("Setting {0:} to training mode".format(_k))
10
           model._modules[_k].train(True)
11
12
# switch on gradients and add parameters to the list of trainable parameters
13
# implemented only for update_type=new_bn (classification module S + all batch
14
# normalization layers in the model
15
# This doesn't apply to running_var, running_mean and batch tracking (frozen)
16
# in batch normalization layers
17
# This assumes that trainable layers have either 'new' or 'bn' in their name
18
def switch_model_on(model, ckpt, list_trained_pars):
19
    param_names = ckpt['model_weights'].keys()
20
    for _n,_p in model.named_parameters():
21
      if _p.dtype==torch.float32 and _n in param_names:
22
         if not 'new' in _n and not 'bn' in _n:
23
            _p.requires_grad_(True)
24
            print(_n, "grads on")
25
         else:
26
            _p.requires_grad_(True)
27
            list_trained_pars.append(_p)
28
            print(_n, "trainable pars")
29
      elif _p.dtype==torch.float32 and not _n in param_names:
30
         _p.requires_grad_(True)
31
         list_trained_pars.append(_p)
32
         print(_n, "new pars, trainable")
33
34
# AVERAGE PRECISION COMPUTATION
35
# adapted from Matterport Mask R-CNN implementation
36
# https://github.com/matterport/Mask_RCNN
37
# inputs are predicted masks>threshold (0.5)
38
def compute_overlaps_masks(masks1, masks2):
39
    # masks1: (HxWxnum_pred)
40
    # masks2: (HxWxnum_gts)
41
    # flatten masks and compute their areas
42
    # masks1: num_pred x H*W
43
    # masks2: num_gt x H*W
44
    # overlap: num_pred x num_gt
45
    masks1 = masks1.flatten(start_dim=1)
46
    masks2 = masks2.flatten(start_dim=1)
47
    area2 = masks2.sum(dim=(1,), dtype=torch.float)
48
    area1 = masks1.sum(dim=(1,), dtype=torch.float)
49
    # duplicatae each predicted mask num_gt times, compute the union (sum) of areas
50
    # num_pred x num_gt
51
    area1 = area1.unsqueeze_(1).expand(*[area1.size()[0], area2.size()[0]])
52
    union = area1 + area2
53
    # intersections and union: transpose predictions, the overlap matrix is num_predxnum_gts
54
    intersections = masks1.float().matmul(masks2.t().float())
55
    # +1: divide by 0
56
    overlaps = intersections / (union-intersections)
57
    return overlaps
58
59
60
# compute average precision for the  specified IoU threshold
61
def compute_matches(gt_boxes, gt_class_ids, gt_masks,
62
                    pred_boxes, pred_class_ids, pred_scores, pred_masks,
63
                    iou_threshold=0.5):
64
    # Sort predictions by score from high to low
65
    indices = pred_scores.argsort().flip(dims=(0,))
66
    pred_boxes = pred_boxes[indices]
67
    pred_class_ids = pred_class_ids[indices]
68
    pred_scores = pred_scores[indices]
69
    pred_masks = pred_masks[indices,...]
70
    # Compute IoU overlaps [pred_masks, gt_masks]
71
    overlaps = compute_overlaps_masks(pred_masks, gt_masks)
72
    # separate predictions for each gt object (a total of gt_masks splits
73
    split_overlaps = overlaps.t().split(1)
74
    # Loop through predictions and find matching ground truth boxes
75
    match_count = 0
76
    # At the start all predictions are False Positives, all gts are False Negatives
77
    pred_match = torch.tensor([-1]).expand(pred_boxes.size()[0]).float()
78
    gt_match = torch.tensor([-1]).expand(gt_boxes.size()[0]).float()
79
    # Alex: loop through each column (gt object), get
80
    for _i, splits in enumerate(split_overlaps):
81
        # ground truth class
82
        gt_class = gt_class_ids[_i]
83
        if (splits>iou_threshold).any():
84
           # get best predictions, their indices inthe IoU tensor and their classes
85
           global_best_preds_inds = torch.nonzero(splits[0]>iou_threshold).view(-1)
86
           pred_classes = pred_class_ids[global_best_preds_inds]
87
           best_preds = splits[0][splits[0]>iou_threshold]
88
           #  sort them locally-nothing else,
89
           local_best_preds_sorted = best_preds.argsort().flip(dims=(0,))
90
           # loop through each prediction's index, sorted in the descending order
91
           for p in local_best_preds_sorted:
92
               if pred_classes[p]==gt_class:
93
                  # Hit?
94
                  match_count +=1
95
                  pred_match[global_best_preds_inds[p]] = _i
96
                  gt_match[_i] = global_best_preds_inds[p]
97
                  # important: if the prediction is True Positive, finish the loop
98
                  break
99
100
    return gt_match, pred_match, overlaps
101
102
103
# AP for a single IoU threshold and 1 image
104
def compute_ap(gt_boxes, gt_class_ids, gt_masks,
105
               pred_boxes, pred_class_ids, pred_scores, pred_masks,
106
               iou_threshold=0.5):
107
108
    # Get matches and overlaps
109
    gt_match, pred_match, overlaps = compute_matches(
110
        gt_boxes, gt_class_ids, gt_masks,
111
        pred_boxes, pred_class_ids, pred_scores, pred_masks,
112
        iou_threshold)
113
114
    # Compute precision and recall at each prediction box step
115
    precisions = (pred_match>-1).cumsum(dim=0).float().div(torch.arange(pred_match.numel()).float()+1)
116
    recalls = (pred_match>-1).cumsum(dim=0).float().div(gt_match.numel())
117
    # Pad with start and end values to simplify the math
118
    precisions = torch.cat([torch.tensor([0]).float(), precisions, torch.tensor([0]).float()])
119
    recalls = torch.cat([torch.tensor([0]).float(), recalls, torch.tensor([1]).float()])
120
    # Ensure precision values decrease but don't increase. This way, the
121
    # precision value at each recall threshold is the maximum it can be
122
    # for all following recall thresholds, as specified by the VOC paper.
123
    for i in range(len(precisions) - 2, -1, -1):
124
        precisions[i] = torch.max(precisions[i], precisions[i + 1])
125
    # Compute mean AP over recall range
126
    indices = torch.nonzero(recalls[:-1] !=recalls[1:]).squeeze_(1)+1
127
    map = torch.sum((recalls[indices] - recalls[indices - 1]) *
128
                 precisions[indices])
129
    return map, precisions, recalls, overlaps
130
131
132
# easier boolean argument
133
def str_to_bool(v):
134
    return v.lower() in ('true')
135