kwimage.algo._nms_backend.torch_nms

Module Contents

Functions

torch_nms(ltrb, scores, classes=None, thresh=0.5, bias=0, fast=False)

Non-maximum suppression implemented with PyTorch tensors.

test_class_torch()

Attributes

torch

kwimage.algo._nms_backend.torch_nms.torch
kwimage.algo._nms_backend.torch_nms.torch_nms(ltrb, scores, classes=None, thresh=0.5, bias=0, fast=False)

Non-maximum suppression implemented with PyTorch tensors.

CURRENTLY NOT WORKING

Parameters
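  • ltrb (Tensor) – Bounding boxes of one image in ltrb format (left, top, right, bottom), with shape (N, 4).

  • scores (Tensor) – Score of each box, with shape (N,).

  • classes (Tensor, optional) – The class of each box. If specified, NMS is applied to each class separately.

  • thresh (float) – IoU (intersection over union) threshold used to decide whether two boxes overlap enough for one to suppress the other.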
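For reference, `thresh` and `classes` usually combine like this in NMS: two boxes conflict when their IoU (intersection area divided by union area) exceeds `thresh`, and, when `classes` is given, only boxes with the same label can conflict. The plain-Python sketch below builds such a pairwise conflict matrix. It is not this module's tensor code; `iou_ltrb` and `conflict_matrix` are hypothetical helpers, and the strict `>` comparison against `thresh` is an assumption.

```python
from typing import List, Optional, Sequence

Box = Sequence[float]  # (left, top, right, bottom)


def iou_ltrb(a: Box, b: Box) -> float:
    """Intersection-over-union of two ltrb boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap width (0 if disjoint)
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap height (0 if disjoint)
    inter = iw * ih
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def conflict_matrix(ltrb: Sequence[Box], thresh: float,
                    classes: Optional[Sequence[int]] = None) -> List[List[bool]]:
    """conflict[i][j] is True when boxes i != j have IoU > thresh and,
    if classes is given, also share the same class label."""
    n = len(ltrb)
    conflict = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # a box never conflicts with itself
            if classes is not None and classes[i] != classes[j]:
                continue  # boxes of different classes never suppress each other
            conflict[i][j] = iou_ltrb(ltrb[i], ltrb[j]) > thresh
    return conflict


boxes = [[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]]
print(conflict_matrix(boxes, 0.5))                     # boxes 0 and 1 conflict (IoU 0.81)
print(conflict_matrix(boxes, 0.5, classes=[0, 1, 0]))  # no conflicts across classes
```

Masking by class inside the conflict matrix keeps the whole computation pairwise, which is the shape of work that maps naturally onto tensor operations.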

Returns

keep: a boolean mask with one entry per input box; an entry is True (nonzero) if that box was kept rather than pruned.

Return type

ByteTensor

Example

>>> # DISABLE_DOCTEST
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> import numpy as np
>>> ltrb = torch.FloatTensor(np.array([
>>>     [0, 0, 100, 100],
>>>     [100, 100, 10, 10],
>>>     [10, 10, 100, 100],
>>>     [50, 50, 100, 100],
>>>     [100, 100, 130, 130],
>>>     [100, 100, 130, 130],
>>>     [100, 100, 130, 130],
>>> ], dtype=np.float32))
>>> scores = torch.FloatTensor(np.array([.1, .5, .9, .1, .3, .5, .4]))
>>> classes = torch.LongTensor(np.array([0, 0, 0, 0, 0, 0, 0]))
>>> thresh = .5
>>> flags = torch_nms(ltrb, scores, classes, thresh)
>>> keep = torch.nonzero(flags).view(-1)  # indices of the kept boxes
>>> ltrb[flags.bool()]  # kept boxes selected with the boolean mask
>>> ltrb[keep]

Example

>>> # DISABLE_DOCTEST
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> import numpy as np
>>> # Test to check that conflicts are correctly resolved
>>> ltrb = torch.FloatTensor(np.array([
>>>     [100, 100, 150, 101],
>>>     [120, 100, 180, 101],
>>>     [150, 100, 200, 101],
>>> ], dtype=np.float32))
>>> scores = torch.FloatTensor(np.linspace(.8, .9, len(ltrb)))
>>> classes = None
>>> thresh = .3
>>> keep = torch_nms(ltrb, scores, classes, thresh, fast=False)
>>> ltrb[keep.bool()]
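The second example checks how chained overlaps are resolved. Box 1 overlaps both neighbours (IoU 30/80 = 0.375 > 0.3), while box 0 overlaps only box 1. Under standard greedy NMS, box 2 (highest score) suppresses box 1, and a suppressed box cannot suppress anything, so box 0 survives. The plain-Python sketch below contrasts that with a one-shot rule that drops any box overlapping any higher-scoring box. `greedy_keep` and `one_shot_keep` are hypothetical names, not kwimage functions, and whether `fast=True` behaves like the one-shot rule is an assumption we have not verified.

```python
from typing import List, Sequence

Box = Sequence[float]  # (left, top, right, bottom)


def iou_ltrb(a: Box, b: Box) -> float:
    """Intersection-over-union of two ltrb boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap width
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap height
    inter = iw * ih
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def rank_order(scores: Sequence[float]) -> List[int]:
    """Box indices sorted by descending score (stable for ties)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


def greedy_keep(ltrb: Sequence[Box], scores: Sequence[float], thresh: float) -> List[bool]:
    """Greedy NMS: only boxes that survive may suppress lower-scoring ones."""
    order = rank_order(scores)
    keep = [True] * len(ltrb)
    for r, i in enumerate(order):
        if not keep[i]:
            continue  # a suppressed box cannot suppress anything
        for j in order[r + 1:]:  # every lower-scoring box
            if iou_ltrb(ltrb[i], ltrb[j]) > thresh:
                keep[j] = False
    return keep


def one_shot_keep(ltrb: Sequence[Box], scores: Sequence[float], thresh: float) -> List[bool]:
    """Drop a box if it overlaps ANY higher-scoring box, suppressed or not."""
    order = rank_order(scores)
    keep = [True] * len(ltrb)
    for r, i in enumerate(order):
        if any(iou_ltrb(ltrb[i], ltrb[j]) > thresh for j in order[:r]):
            keep[i] = False
    return keep


# Boxes from the second example; np.linspace(.8, .9, 3) gives (approximately) these scores
ltrb = [[100, 100, 150, 101], [120, 100, 180, 101], [150, 100, 200, 101]]
scores = [0.8, 0.85, 0.9]
print(greedy_keep(ltrb, scores, 0.3))    # [True, False, True]
print(one_shot_keep(ltrb, scores, 0.3))  # [False, False, True]
```

The one-shot rule needs no sequential sweep, which makes it easy to vectorize, but as shown it also discards box 0, whose only conflict was already suppressed.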
kwimage.algo._nms_backend.torch_nms.test_class_torch()