--- a
+++ b/custom_extensions/nms/nms.py
@@ -0,0 +1,75 @@
+"""
+Adapted from PyTorch's torchvision package (torchvision.ops.boxes).
+
+"""
+
+import torch
+import nms_extension
+
+def nms(boxes, scores, iou_threshold):
+    """
+    Performs non-maximum suppression (NMS) on the boxes according
+    to their intersection-over-union (IoU).
+
+    NMS iteratively removes lower-scoring boxes which have an
+    IoU greater than iou_threshold with another (higher-scoring)
+    box. Thin wrapper around nms_extension.nms.
+
+    Parameters
+    ----------
+    boxes : Tensor[N, 4] for 2D or Tensor[N, 6] for 3D
+        boxes to perform NMS on. They are expected to be
+        in (y1, x1, y2, x2(, z1, z2)) format
+    scores : Tensor[N]
+        scores for each one of the boxes
+    iou_threshold : float
+        boxes whose IoU with another, higher-scoring box is
+        greater than iou_threshold are discarded
+
+    Returns
+    -------
+    keep : Tensor
+        int64 tensor with the indices of the elements
+        that have been kept by NMS, sorted in decreasing
+        order of scores
+    """
+    return nms_extension.nms(boxes, scores, iou_threshold)
+
+
+def batched_nms(boxes, scores, idxs, iou_threshold):
+    """
+    Performs non-maximum suppression in a batched fashion.
+
+    Each index value corresponds to a category, and NMS
+    will not be applied between elements of different categories.
+
+    Parameters
+    ----------
+    boxes : Tensor[N, 4] for 2D or Tensor[N, 6] for 3D
+        boxes to perform NMS on. They are expected to be in
+        (y1, x1, y2, x2(, z1, z2)) format
+    scores : Tensor[N]
+        scores for each one of the boxes
+    idxs : Tensor[N]
+        category index of each box (NMS is applied per category)
+    iou_threshold : float
+        within a category, boxes whose IoU with another,
+        higher-scoring box is greater than iou_threshold are discarded
+
+    Returns
+    -------
+    keep : Tensor
+        int64 tensor with the indices of the elements that have
+        been kept by NMS, sorted in decreasing order of scores
+        (an empty int64 tensor on boxes.device when boxes is empty)
+    """
+    if boxes.numel() == 0:
+        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+    # Shift each category into its own coordinate range so that boxes of
+    # different categories never overlap (IoU == 0) and one NMS call acts
+    # per category. A step of (max - min + 1) keeps the ranges disjoint for
+    # any coordinates; a step of (max + 1) is only safe when min > -1.
+    coord_span = boxes.max() - boxes.min() + 1
+    offsets = idxs.to(boxes) * coord_span
+    boxes_for_nms = boxes + offsets[:, None]
+    return nms(boxes_for_nms, scores, iou_threshold)