--- a +++ b/custom_extensions/nms/nms.py @@ -0,0 +1,75 @@ +""" +adopted from pytorch framework, torchvision.ops.boxes + +""" + +import torch +import nms_extension + +def nms(boxes, scores, iou_threshold): + """ + Performs non-maximum suppression (NMS) on the boxes according + to their intersection-over-union (IoU). + + NMS iteratively removes lower scoring boxes which have an + IoU greater than iou_threshold with another (higher scoring) + box. + + Parameters + ---------- + boxes : Tensor[N, 4] for 2D or Tensor[N,6] for 3D. + boxes to perform NMS on. They + are expected to be in (y1, x1, y2, x2(, z1, z2)) format + scores : Tensor[N] + scores for each one of the boxes + iou_threshold : float + discards all overlapping + boxes with IoU < iou_threshold + + Returns + ------- + keep : Tensor + int64 tensor with the indices + of the elements that have been kept + by NMS, sorted in decreasing order of scores + """ + return nms_extension.nms(boxes, scores, iou_threshold) + + +def batched_nms(boxes, scores, idxs, iou_threshold): + """ + Performs non-maximum suppression in a batched fashion. + + Each index value correspond to a category, and NMS + will not be applied between elements of different categories. + + Parameters + ---------- + boxes : Tensor[N, 4] for 2D or Tensor[N,6] for 3D. + boxes to perform NMS on. They + are expected to be in (y1, x1, y2, x2(, z1, z2)) format + scores : Tensor[N] + scores for each one of the boxes + idxs : Tensor[N] + indices of the categories for each one of the boxes. + iou_threshold : float + discards all overlapping boxes + with IoU < iou_threshold + + Returns + ------- + keep : Tensor + int64 tensor with the indices of + the elements that have been kept by NMS, sorted + in decreasing order of scores + """ + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + # strategy: in order to perform NMS independently per class. + # we add an offset to all the boxes. The offset is dependent + # only on the class idx, and is large enough so that boxes + # from different classes do not overlap + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + return nms(boxes_for_nms, scores, iou_threshold)