"""Helpers for selective fine-tuning and mask average-precision evaluation.

The average-precision code is adapted from the Matterport Mask R-CNN
implementation (https://github.com/matterport/Mask_RCNN) and operates on
torch tensors instead of numpy arrays.
"""
import torch
import torch.nn


def set_to_train_mode(model, report=True):
    """Switch the direct children of *model* named with ``'new'`` to train mode.

    Only entries of ``model._modules`` whose key contains the substring
    ``'new'`` are touched; all other children keep their current mode.

    Args:
        model: module whose direct children are inspected.
        report: when true, print the name of every child that is switched.
    """
    for name, module in model._modules.items():
        if 'new' in name:
            if report:
                print("Setting {0:} to training mode".format(name))
            module.train(True)


def switch_model_on(model, ckpt, list_trained_pars):
    """Enable gradients and collect the parameters that should be optimized.

    Implemented for ``update_type=new_bn`` and assumes that trainable layers
    have either ``'new'`` or ``'bn'`` in their parameter names.  Only
    ``named_parameters()`` are visited, so batch-norm buffers (running mean,
    running variance, batch tracking) are not affected.

    Every float32 parameter gets ``requires_grad`` switched on.  A parameter
    is additionally appended to *list_trained_pars* when it is either absent
    from the checkpoint or has ``'new'``/``'bn'`` in its name.  Parameters of
    any other dtype are left untouched.

    Args:
        model: module whose parameters are processed.
        ckpt: mapping with a ``'model_weights'`` entry; only the keys of that
            entry (the checkpointed parameter names) are used.
        list_trained_pars: list extended in place with trainable parameters.
    """
    pretrained_names = ckpt['model_weights'].keys()
    for name, param in model.named_parameters():
        if param.dtype != torch.float32:
            continue
        param.requires_grad_(True)
        if name not in pretrained_names:
            list_trained_pars.append(param)
            print(name, "new pars, trainable")
        elif 'new' in name or 'bn' in name:
            list_trained_pars.append(param)
            print(name, "trainable pars")
        else:
            # Gradients flow through the parameter, but it is not collected.
            print(name, "grads on")


# AVERAGE PRECISION COMPUTATION
# adapted from the Matterport Mask R-CNN implementation
# https://github.com/matterport/Mask_RCNN
# inputs are predicted masks > threshold (0.5)
def compute_overlaps_masks(masks1, masks2):
    """Return the pairwise IoU matrix between two stacks of masks.

    Each stack is flattened from its second dimension on, so the first
    dimension indexes the masks (e.g. ``num_pred x H x W`` and
    ``num_gt x H x W``).

    Args:
        masks1: stack of predicted masks, one mask per leading index.
        masks2: stack of ground-truth masks, one mask per leading index.

    Returns:
        Float tensor of shape ``num_pred x num_gt`` holding the IoU of every
        pair.  A pair whose union is empty gets IoU 0 instead of NaN.
    """
    # Flatten every mask to one row: num_masks x (H*W).
    flat1 = masks1.flatten(start_dim=1).float()
    flat2 = masks2.flatten(start_dim=1).float()
    area1 = flat1.sum(dim=1)
    area2 = flat2.sum(dim=1)
    # Pairwise intersections: num_pred x num_gt.
    intersections = flat1.matmul(flat2.t())
    # Broadcast the areas to num_pred x num_gt to obtain the pairwise unions.
    union = area1.unsqueeze(1) + area2.unsqueeze(0) - intersections
    # Guard against 0/0 when both masks of a pair are empty.
    safe_union = torch.where(union == 0, torch.ones_like(union), union)
    return intersections / safe_union


def compute_matches(gt_boxes, gt_class_ids, gt_masks,
                    pred_boxes, pred_class_ids, pred_scores, pred_masks,
                    iou_threshold=0.5):
    """Match predictions to ground-truth objects by mask IoU and class.

    Predictions are first sorted by score (high to low).  For every
    ground-truth object, the predictions with IoU above *iou_threshold* are
    visited in order of decreasing IoU and the first one with the same class
    id is recorded as its match.

    Args:
        gt_boxes: ground-truth boxes; only the leading size is used.
        gt_class_ids: class id of every ground-truth object.
        gt_masks: ground-truth masks, one per leading index.
        pred_boxes: predicted boxes; only the leading size is used.
        pred_class_ids: class id of every prediction.
        pred_scores: confidence score of every prediction.
        pred_masks: predicted masks, one per leading index.
        iou_threshold: IoU a prediction must exceed to be a candidate.

    Returns:
        Tuple ``(gt_match, pred_match, overlaps)``: for every ground-truth
        object the index of its matched (score-sorted) prediction or -1, for
        every (score-sorted) prediction the index of its matched ground-truth
        object or -1, and the ``num_pred x num_gt`` IoU matrix.
    """
    # Sort predictions by score from high to low.
    order = pred_scores.argsort().flip(dims=(0,))
    pred_boxes = pred_boxes[order]
    pred_class_ids = pred_class_ids[order]
    pred_masks = pred_masks[order, ...]
    # IoU overlaps [pred_masks, gt_masks].
    overlaps = compute_overlaps_masks(pred_masks, gt_masks)
    # At the start all predictions are false positives and all ground-truth
    # objects are false negatives.
    pred_match = torch.full((pred_boxes.size(0),), -1.0)
    gt_match = torch.full((gt_boxes.size(0),), -1.0)
    # Iterate over the columns of the IoU matrix (one per ground-truth
    # object).  Iterating by index keeps the loop empty when there are no
    # ground-truth objects; Tensor.split would still yield one empty chunk.
    for gt_idx in range(overlaps.size(1)):
        ious = overlaps[:, gt_idx]
        candidates = torch.nonzero(ious > iou_threshold).view(-1)
        if candidates.numel() == 0:
            continue
        gt_class = gt_class_ids[gt_idx]
        # Candidate predictions ordered by decreasing IoU.
        ranked = candidates[ious[candidates].argsort().flip(dims=(0,))]
        for pred_idx in ranked:
            if pred_class_ids[pred_idx] == gt_class:
                # Hit: record the pair and stop looking for this object.
                pred_match[int(pred_idx)] = gt_idx
                gt_match[gt_idx] = int(pred_idx)
                break

    return gt_match, pred_match, overlaps


def compute_ap(gt_boxes, gt_class_ids, gt_masks,
               pred_boxes, pred_class_ids, pred_scores, pred_masks,
               iou_threshold=0.5):
    """Compute the average precision of one image at one IoU threshold.

    Arguments are forwarded unchanged to ``compute_matches``.

    Returns:
        Tuple ``(ap, precisions, recalls, overlaps)``: the average precision
        as a 0-dim tensor, the padded precision and recall curves, and the
        IoU matrix from ``compute_matches``.  With no ground-truth objects
        the recall denominator is taken as 1, so the result is 0, not NaN.
    """
    gt_match, pred_match, overlaps = compute_matches(
        gt_boxes, gt_class_ids, gt_masks,
        pred_boxes, pred_class_ids, pred_scores, pred_masks,
        iou_threshold)

    # Precision and recall after each (score-sorted) prediction.
    hits = (pred_match > -1).cumsum(dim=0).float()
    ranks = torch.arange(pred_match.numel()).float() + 1
    precisions = hits / ranks
    recalls = hits / max(gt_match.numel(), 1)
    # Pad with start and end values to simplify the math.
    precisions = torch.cat([torch.zeros(1), precisions, torch.zeros(1)])
    recalls = torch.cat([torch.zeros(1), recalls, torch.ones(1)])
    # Ensure precision values decrease but don't increase.  This way, the
    # precision value at each recall threshold is the maximum it can be
    # for all following recall thresholds, as specified by the VOC paper.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = torch.max(precisions[i], precisions[i + 1])
    # Sum precision over the recall steps.
    steps = torch.nonzero(recalls[:-1] != recalls[1:]).squeeze(1) + 1
    mean_ap = torch.sum((recalls[steps] - recalls[steps - 1]) *
                        precisions[steps])
    return mean_ap, precisions, recalls, overlaps


def str_to_bool(v):
    """Return True only when *v* is the string ``'true'`` in any letter case."""
    return v.lower() == 'true'