:py:mod:`kwimage.algo._nms_backend.torch_nms`
=============================================

.. py:module:: kwimage.algo._nms_backend.torch_nms


Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   kwimage.algo._nms_backend.torch_nms.torch_nms
   kwimage.algo._nms_backend.torch_nms.test_class_torch



Attributes
~~~~~~~~~~

.. autoapisummary::

   kwimage.algo._nms_backend.torch_nms.torch


.. py:data:: torch


.. py:function:: torch_nms(ltrb, scores, classes=None, thresh=0.5, bias=0, fast=False)

   Non-maximum suppression implemented with PyTorch tensors.

   CURRENTLY NOT WORKING

   :Parameters: * **ltrb** (*Tensor*) -- Bounding boxes of one image in ltrb (left, top, right, bottom) format.
                * **scores** (*Tensor*) -- Score of each box.
                * **classes** (*Tensor, optional*) -- Class of each box. If specified, NMS is applied to each class separately.
                * **thresh** (*float*) -- IoU threshold.

   :returns: keep: boolean mask indicating which boxes were not pruned.
   :rtype: ByteTensor

   .. rubric:: Example

   >>> # DISABLE_DOCTEST
   >>> # xdoctest: +REQUIRES(module:torch)
   >>> import torch
   >>> import numpy as np
   >>> ltrb = torch.FloatTensor(np.array([
   >>>     [0, 0, 100, 100],
   >>>     [100, 100, 10, 10],
   >>>     [10, 10, 100, 100],
   >>>     [50, 50, 100, 100],
   >>>     [100, 100, 130, 130],
   >>>     [100, 100, 130, 130],
   >>>     [100, 100, 130, 130],
   >>> ], dtype=np.float32))
   >>> scores = torch.FloatTensor(np.array([.1, .5, .9, .1, .3, .5, .4]))
   >>> classes = torch.LongTensor(np.array([0, 0, 0, 0, 0, 0, 0]))
   >>> thresh = .5
   >>> flags = torch_nms(ltrb, scores, classes, thresh)
   >>> # Convert the boolean mask into the indices of the kept boxes
   >>> keep = torch.nonzero(flags).view(-1)
   >>> ltrb[flags]
   >>> ltrb[keep]

   .. rubric:: Example

   >>> # DISABLE_DOCTEST
   >>> # xdoctest: +REQUIRES(module:torch)
   >>> import torch
   >>> import numpy as np
   >>> # Test to check that conflicts are correctly resolved
   >>> ltrb = torch.FloatTensor(np.array([
   >>>     [100, 100, 150, 101],
   >>>     [120, 100, 180, 101],
   >>>     [150, 100, 200, 101],
   >>> ], dtype=np.float32))
   >>> scores = torch.FloatTensor(np.linspace(.8, .9, len(ltrb)))
   >>> classes = None
   >>> thresh = .3
   >>> keep = torch_nms(ltrb, scores, classes, thresh, fast=False)
   >>> ltrb[keep]


.. py:function:: test_class_torch()
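
Because ``torch_nms`` is marked as not working, the following is a minimal reference sketch of the greedy, per-class suppression that its docstring describes, written with NumPy instead of PyTorch. The function ``greedy_nms_sketch`` is hypothetical and not part of kwimage. It assumes that a box is suppressed when its IoU with a higher-scoring kept box of the same class is strictly greater than ``thresh``, and it ignores the ``bias`` and ``fast`` parameters. Like ``torch_nms``, it returns a boolean keep mask.

.. code-block:: python

   import numpy as np


   def greedy_nms_sketch(ltrb, scores, classes=None, thresh=0.5):
       """Hypothetical greedy NMS sketch (not part of kwimage).

       Assumption: a box is suppressed when its IoU with a higher-scoring
       kept box of the same class is strictly greater than ``thresh``.
       Returns a boolean mask that is True for boxes that survive.
       """
       ltrb = np.asarray(ltrb, dtype=np.float64).reshape(-1, 4)
       scores = np.asarray(scores, dtype=np.float64)
       if classes is None:
           # Without classes, every box competes with every other box
           classes = np.zeros(len(ltrb), dtype=np.int64)
       classes = np.asarray(classes)
       keep = np.zeros(len(ltrb), dtype=bool)

       # Box areas; negative widths or heights (degenerate boxes) count as zero
       widths = np.clip(ltrb[:, 2] - ltrb[:, 0], 0, None)
       heights = np.clip(ltrb[:, 3] - ltrb[:, 1], 0, None)
       areas = widths * heights

       for cls in np.unique(classes):
           # Visit this class's boxes from highest to lowest score
           idxs = np.flatnonzero(classes == cls)
           order = idxs[np.argsort(-scores[idxs], kind="stable")]
           while order.size > 0:
               best, rest = order[0], order[1:]
               keep[best] = True
               # Intersection of the best box with each remaining box
               iw = np.clip(np.minimum(ltrb[best, 2], ltrb[rest, 2])
                            - np.maximum(ltrb[best, 0], ltrb[rest, 0]), 0, None)
               ih = np.clip(np.minimum(ltrb[best, 3], ltrb[rest, 3])
                            - np.maximum(ltrb[best, 1], ltrb[rest, 1]), 0, None)
               inter = iw * ih
               union = areas[best] + areas[rest] - inter
               # IoU is taken as 0 when both boxes have zero area
               iou = np.divide(inter, union, out=np.zeros_like(inter),
                               where=union > 0)
               # Drop every remaining box that overlaps the kept box too much
               order = rest[iou <= thresh]
       return keep

On the data from the two ``torch_nms`` examples, this sketch keeps boxes 1, 2, 3 and 5 of the first example. In the second example it keeps the two outer boxes and suppresses the middle one, whose IoU with the highest-scoring box is 0.375. Porting the sketch to PyTorch mostly means swapping the NumPy calls for their ``torch`` counterparts, such as ``torch.argsort`` and ``torch.clamp``.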