|
a |
|
b/utils.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn |
|
|
3 |
|
|
|
4 |
# set every direct child module whose name contains 'new' to training mode |
|
|
5 |
def set_to_train_mode(model, report=True):
    """Switch the direct children of ``model`` named like 'new' to training mode.

    Only the entries of ``model._modules`` (the immediate children) are
    inspected; a child is selected when the substring 'new' occurs in its
    registered name. Every other child keeps its current mode.

    Args:
        model: module whose immediate children are scanned.
        report: when true, print one line per child that is switched.
    """
    selected = [name for name in model._modules if 'new' in name]
    for name in selected:
        if report:
            print("Setting {0:} to training mode".format(name))
        model._modules[name].train(True)
|
|
11 |
|
|
|
12 |
# switch on gradients and add parameters to the list of trainable parameters |
|
|
13 |
# implemented only for update_type=new_bn (classification module S + all batch |
|
|
14 |
# normalization layers in the model) |
|
|
15 |
# This doesn't apply to running_var, running_mean and batch tracking (frozen) |
|
|
16 |
# in batch normalization layers |
|
|
17 |
# This assumes that trainable layers have either 'new' or 'bn' in their name |
|
|
18 |
def switch_model_on(model, ckpt, list_trained_pars):
    """Enable gradients on float32 parameters and collect the trainable ones.

    Every float32 parameter of ``model`` gets ``requires_grad`` switched on.
    A parameter is additionally appended to ``list_trained_pars`` when either
    its name is absent from ``ckpt['model_weights']`` (a new parameter), or
    its name contains 'new' or 'bn'. Parameters of any other dtype are left
    untouched. A status line is printed for each float32 parameter.

    Args:
        model: module whose ``named_parameters()`` are visited.
        ckpt: mapping with a 'model_weights' entry keyed by parameter name.
        list_trained_pars: list extended in place with the trainable
            parameters, in ``named_parameters()`` order.
    """
    checkpoint_names = ckpt['model_weights'].keys()
    for name, param in model.named_parameters():
        if param.dtype != torch.float32:
            continue
        # Gradients are enabled for every float32 parameter, whatever branch
        # it falls into below.
        param.requires_grad_(True)
        if name not in checkpoint_names:
            list_trained_pars.append(param)
            print(name, "new pars, trainable")
        elif 'new' in name or 'bn' in name:
            list_trained_pars.append(param)
            print(name, "trainable pars")
        else:
            print(name, "grads on")
|
|
33 |
|
|
|
34 |
# AVERAGE PRECISION COMPUTATION |
|
|
35 |
# adapted from Matterport Mask R-CNN implementation |
|
|
36 |
# https://github.com/matterport/Mask_RCNN |
|
|
37 |
# inputs are predicted masks>threshold (0.5) |
|
|
38 |
def compute_overlaps_masks(masks1, masks2):
    """Compute the pairwise IoU between two stacks of masks.

    Each argument is flattened from dimension 1 onwards, so the first
    dimension indexes the individual masks and everything after it is treated
    as that mask's pixels. Both stacks must therefore have the same number of
    elements per mask.

    Args:
        masks1: predicted masks, one per index of the first dimension
            (presumably num_pred x H x W, already thresholded -- confirm
            against the caller).
        masks2: ground-truth masks, one per index of the first dimension.

    Returns:
        Float tensor of shape (num_pred, num_gt) holding intersection over
        union for every prediction / ground-truth pair. A pair whose union is
        empty (both masks all zero) gets an IoU of 0 instead of NaN.
    """
    # masks1: num_pred x H*W, masks2: num_gt x H*W
    flat_pred = masks1.flatten(start_dim=1).float()
    flat_gt = masks2.flatten(start_dim=1).float()
    area_pred = flat_pred.sum(dim=1)
    area_gt = flat_gt.sum(dim=1)
    # Dot products of the flattened masks count the shared pixels:
    # num_pred x num_gt.
    intersections = flat_pred.matmul(flat_gt.t())
    # Broadcast the per-mask areas to num_pred x num_gt; the intersection is
    # counted in both areas, so subtract it once.
    unions = area_pred.unsqueeze(1) + area_gt.unsqueeze(0) - intersections
    # Guard the 0/0 case: two empty masks would otherwise produce NaN.
    overlaps = torch.where(unions > 0,
                           intersections / unions,
                           torch.zeros_like(unions))
    return overlaps
|
|
58 |
|
|
|
59 |
|
|
|
60 |
# match predictions to ground-truth objects at the specified IoU threshold |
|
|
61 |
def compute_matches(gt_boxes, gt_class_ids, gt_masks,
                    pred_boxes, pred_class_ids, pred_scores, pred_masks,
                    iou_threshold=0.5):
    """Match predictions to ground-truth objects by mask IoU and class.

    Predictions are first sorted by score, highest first. Then, for each
    ground-truth object, the predictions whose mask IoU with it exceeds
    ``iou_threshold`` are visited in order of decreasing IoU, and the first
    one with the same class id is recorded as its match.

    Args:
        gt_boxes: ground-truth boxes; only the size of the first dimension is
            used (number of ground-truth objects).
        gt_class_ids: class id of each ground-truth object.
        gt_masks: ground-truth masks, passed to ``compute_overlaps_masks``.
        pred_boxes: predicted boxes; only the size of the first dimension is
            used (number of predictions).
        pred_class_ids: class id of each prediction.
        pred_scores: confidence of each prediction, used only for sorting.
        pred_masks: predicted masks, indexed along the first dimension.
        iou_threshold: IoU a prediction must exceed to be a candidate.

    Returns:
        Tuple ``(gt_match, pred_match, overlaps)``:
        gt_match -- float tensor, one entry per ground-truth object, holding
            the index of the matched prediction or -1.
        pred_match -- float tensor, one entry per prediction, holding the
            index of the matched ground-truth object or -1.
        overlaps -- IoU matrix returned by ``compute_overlaps_masks``.
        Prediction indices and the rows of ``overlaps`` refer to the
        score-sorted order, not the order of the inputs.
    """
    # Sort predictions by score from high to low.
    order = pred_scores.argsort().flip(dims=(0,))
    pred_boxes = pred_boxes[order]
    pred_class_ids = pred_class_ids[order]
    pred_masks = pred_masks[order, ...]
    # IoU overlaps [pred_masks, gt_masks].
    overlaps = compute_overlaps_masks(pred_masks, gt_masks)
    # At the start all predictions are false positives and all ground-truth
    # objects are false negatives. torch.full allocates real storage, so the
    # in-place writes below never touch an expanded (stride-0) view.
    pred_match = torch.full((pred_boxes.size()[0],), -1.0, dtype=torch.float)
    gt_match = torch.full((gt_boxes.size()[0],), -1.0, dtype=torch.float)
    # Walk the columns of the IoU matrix (one per ground-truth object) by
    # index. Splitting the transposed matrix instead yields one empty chunk
    # when there are no ground-truth objects, which made the class lookup
    # below fail with IndexError.
    # NOTE(review): a prediction that clears the threshold for several
    # ground-truth objects can be matched more than once (pred_match keeps
    # the last one) -- confirm whether one-to-one matching is wanted.
    for gt_idx in range(overlaps.size()[1]):
        gt_class = gt_class_ids[gt_idx]
        ious = overlaps[:, gt_idx]
        above = ious > iou_threshold
        # Indices (in the sorted prediction order) of the candidates.
        candidate_inds = torch.nonzero(above).view(-1)
        if candidate_inds.numel() == 0:
            continue
        # Rank the candidates among themselves by decreasing IoU.
        by_iou = ious[above].argsort().flip(dims=(0,))
        for local_idx in by_iou:
            pred_idx = candidate_inds[local_idx]
            if pred_class_ids[pred_idx] == gt_class:
                pred_match[pred_idx] = gt_idx
                gt_match[gt_idx] = pred_idx
                # This ground-truth object is matched; stop searching.
                break

    return gt_match, pred_match, overlaps
|
|
101 |
|
|
|
102 |
|
|
|
103 |
# AP for a single IoU threshold and 1 image |
|
|
104 |
def compute_ap(gt_boxes, gt_class_ids, gt_masks,
               pred_boxes, pred_class_ids, pred_scores, pred_masks,
               iou_threshold=0.5):
    """Compute average precision for one image at a single IoU threshold.

    All arguments are forwarded unchanged to ``compute_matches``; the
    precision/recall curve is then built from the per-prediction match
    vector it returns.

    Returns:
        Tuple ``(ap, precisions, recalls, overlaps)``:
        ap -- scalar tensor, sum over the recall steps of
            (recall increase) * (precision at that step).
        precisions -- precision after each prediction, padded with 0 at both
            ends and made non-increasing from left to right.
        recalls -- recall after each prediction, padded with 0 at the start
            and 1 at the end.
        overlaps -- IoU matrix returned by ``compute_matches``.
    """
    gt_match, pred_match, overlaps = compute_matches(
        gt_boxes, gt_class_ids, gt_masks,
        pred_boxes, pred_class_ids, pred_scores, pred_masks,
        iou_threshold)

    # Running count of matched predictions at each prediction step.
    hits = (pred_match > -1).cumsum(dim=0).float()
    ranks = torch.arange(pred_match.numel()).float() + 1
    precisions = hits.div(ranks)
    recalls = hits.div(gt_match.numel())

    # Pad with start and end values to simplify the math.
    zero = torch.tensor([0]).float()
    one = torch.tensor([1]).float()
    precisions = torch.cat([zero, precisions, zero])
    recalls = torch.cat([zero, recalls, one])

    # Sweep right to left so each precision becomes the maximum of itself and
    # everything after it (the VOC-style envelope).
    for i in reversed(range(len(precisions) - 1)):
        precisions[i] = torch.max(precisions[i], precisions[i + 1])

    # Positions where recall changes; only those steps contribute to the sum.
    steps = torch.nonzero(recalls[:-1] != recalls[1:]).squeeze_(1) + 1
    recall_deltas = recalls[steps] - recalls[steps - 1]
    average_precision = torch.sum(recall_deltas * precisions[steps])
    return average_precision, precisions, recalls, overlaps
|
|
130 |
|
|
|
131 |
|
|
|
132 |
# easier boolean argument |
|
|
133 |
def str_to_bool(v):
    """Return True when the string ``v`` spells 'true', ignoring case.

    Any other string, including the empty string, gives False.

    Args:
        v: string to interpret, e.g. a command-line argument value.
    """
    # Compare for equality: ``in ('true')`` tests against a plain string, so
    # it accepted every substring of 'true' ('', 't', 'rue', ...).
    return v.lower() == 'true'
|
|
135 |
|