Source code for kwimage.algo.algo_nms

"""
Generic Non-Maximum Suppression API with efficient backend implementations
"""
import sys
import numpy as np
import ubelt as ub
import warnings
import kwarray


def daq_spatial_nms(ltrb, scores, diameter, thresh, max_depth=6,
                    stop_size=2048, recsize=2048, impl='auto', device_id=None):
    """
    Divide and conquer speedup of the non-max-suppression algorithm for when
    bboxes have a known max size

    The boxes are recursively split at the median of alternating coordinates
    (left, then top). Each half is solved independently, and survivors that
    lie within ``diameter`` of a split point are re-checked together, because
    overlapping boxes may have landed in different subproblems.

    Args:
        ltrb (ndarray): boxes in (tlx, tly, brx, bry) format

        scores (ndarray): scores of each box

        diameter (int | Tuple[int, int]):
            Distance from split point to consider rectification. If specified
            as an integer, then number is used for both height and width. If
            specified as a tuple, then dims are assumed to be in
            [height, width] format.

        thresh (float): iou threshold. Boxes are removed if they overlap
            greater than this threshold. 0 is the most strict, resulting in the
            fewest boxes, and 1 is the most permissive resulting in the most.

        max_depth (int): maximum number of times we can divide and conquer

        stop_size (int): number of boxes that triggers full NMS computation

        recsize (int): number of boxes that triggers full NMS recombination

        impl (str): algorithm to use

        device_id (int | None): forwarded to :func:`non_max_supression`

    Returns:
        List[int]: sorted indices (into ``ltrb``) of the boxes to keep

    Note:
        # TODO: Look Into
        # Didn't read yet but it seems similar
        http://www.cyberneum.de/fileadmin/user_upload/files/publications/CVPR2010-Lampert_[0].pdf

        https://www.researchgate.net/publication/220929789_Efficient_Non-Maximum_Suppression

        # This seems very similar
        https://projet.liris.cnrs.fr/m2disco/pub/Congres/2006-ICPR/DATA/C03_0406.PDF

    Example:
        >>> import kwimage
        >>> # Make a bunch of boxes with the same width and height
        >>> #boxes = kwimage.Boxes.random(230397, scale=1000, format='cxywh')
        >>> boxes = kwimage.Boxes.random(237, scale=1000, format='cxywh')
        >>> boxes.data.T[2] = 10
        >>> boxes.data.T[3] = 10
        >>> #
        >>> ltrb = boxes.to_ltrb().data.astype(np.float32)
        >>> scores = np.arange(0, len(ltrb)).astype(np.float32)
        >>> #
        >>> n_megabytes = (ltrb.size * ltrb.dtype.itemsize) / (2 ** 20)
        >>> print('n_megabytes = {!r}'.format(n_megabytes))
        >>> #
        >>> thresh = iou_thresh = 0.01
        >>> impl = 'auto'
        >>> max_depth = 20
        >>> diameter = 10
        >>> stop_size = 2000
        >>> recsize = 500
        >>> #
        >>> import ubelt as ub
        >>> #
        >>> with ub.Timer(label='daq'):
        >>>     keep1 = daq_spatial_nms(ltrb, scores,
        >>>         diameter=diameter, thresh=thresh, max_depth=max_depth,
        >>>         stop_size=stop_size, recsize=recsize, impl=impl)
        >>> #
        >>> with ub.Timer(label='full'):
        >>>     keep2 = non_max_supression(ltrb, scores,
        >>>         thresh=thresh, impl=impl)
        >>> #
        >>> # Due to the greedy nature of the algorithm, there will be slight
        >>> # differences in results, but they will be mostly similar.
        >>> similarity = len(set(keep1) & set(keep2)) / len(set(keep1) | set(keep2))
        >>> print('similarity = {!r}'.format(similarity))
    """
    def _rectify(ltrb, scores, both_keep, needs_rectify):
        """
        Re-run NMS over the boxes flagged in ``needs_rectify`` and merge the
        survivors with the untouched members of ``both_keep``.

        ``scores`` must be the score array that corresponds to ``ltrb``: all
        indices here are local to the current subproblem, so the scores are
        passed explicitly rather than read from the enclosing scope.
        """
        if len(needs_rectify) == 0:
            return sorted(both_keep)
        nr_arr = np.array(sorted(needs_rectify))
        rectified_keep = non_max_supression(
            ltrb[nr_arr], scores[nr_arr], thresh=thresh, impl=impl,
            device_id=device_id)
        rk = set(nr_arr[rectified_keep])
        # Boxes that were not re-checked are kept as-is; re-checked boxes are
        # kept only if they survived the second NMS pass.
        return sorted((set(both_keep) - needs_rectify) | rk)

    def _recurse(ltrb, scores, dim, depth, diameter_wh):
        """
        Args:
            dim (int): flips between 0 and 1
            depth (int): recursion depth

        Returns:
            Tuple[Sequence[int], Set[int]]: indices (local to this
                subproblem) that are currently kept, and the subset of
                indices the caller still has to re-check.
        """
        n_boxes = len(ltrb)
        if depth >= max_depth or n_boxes < stop_size:
            # Small enough (or deep enough): solve this subproblem directly
            keep = non_max_supression(ltrb, scores, thresh=thresh, impl=impl,
                                      device_id=device_id)
            both_keep = sorted(keep)
            needs_rectify = set()
        else:
            # Break up the NMS into two subproblems.
            middle = np.median(ltrb.T[dim])
            left_flags = ltrb.T[dim] < middle
            right_flags = ~left_flags

            left_idxs = np.where(left_flags)[0]
            right_idxs = np.where(right_flags)[0]

            left_scores = scores[left_idxs]
            left_ltrb = ltrb[left_idxs]

            right_scores = scores[right_idxs]
            right_ltrb = ltrb[right_idxs]

            next_depth = depth + 1
            next_dim = 1 - dim

            # Solve each subproblem
            left_keep_, lrec_ = _recurse(
                left_ltrb, left_scores, depth=next_depth, dim=next_dim,
                diameter_wh=diameter_wh)

            right_keep_, rrec_ = _recurse(
                right_ltrb, right_scores, depth=next_depth, dim=next_dim,
                diameter_wh=diameter_wh)

            # Recombine the results (note that because we have a diameter_wh,
            # we have to check less results). Child indices are mapped back
            # into this subproblem's index space.
            rrec = set(right_idxs[sorted(rrec_)])
            lrec = set(left_idxs[sorted(lrec_)])

            left_keep = left_idxs[left_keep_]
            right_keep = right_idxs[right_keep_]

            both_keep = np.hstack([left_keep, right_keep])
            both_keep.sort()

            dist_to_middle = np.abs(ltrb[both_keep].T[dim] - middle)

            # Find all surviving boxes that are close to the midpoint. We will
            # need to recheck these because they may overlap, but they also may
            # have been split into different subproblems.
            rectify_flags = dist_to_middle < diameter_wh[dim]

            needs_rectify = set(both_keep[rectify_flags])
            needs_rectify.update(rrec)
            needs_rectify.update(lrec)

            nrec = len(needs_rectify)
            if nrec > recsize:
                # Too many pending re-checks to defer: resolve them now
                both_keep = _rectify(ltrb, scores, both_keep, needs_rectify)
                needs_rectify = set()
        return both_keep, needs_rectify

    if not ub.iterable(diameter):
        diameter_wh = [diameter, diameter]
    else:
        # Given as [height, width]; dim 0 splits on x, so reverse to (w, h)
        diameter_wh = diameter[::-1]

    depth = 0
    dim = 0
    both_keep, needs_rectify = _recurse(ltrb, scores, dim=dim, depth=depth,
                                        diameter_wh=diameter_wh)
    keep = _rectify(ltrb, scores, both_keep, needs_rectify)
    return keep
_impls = None


class _NMS_Impls():
    """
    Lazily populated registry of the NMS backends usable in this environment.

    Attributes are ``None`` until :meth:`_lazy_init` has been called:

    * ``_funcs`` - dict mapping an implementation name to its callable
    * ``_valid`` - frozenset of the registered implementation names
    """
    # TODO: could make this prettier

    def __init__(self):
        self._funcs = None
        self._valid = None

    def _lazy_init(self):
        """
        Import the optional backends and register those that are available.

        Failures to import an optional backend are reported with a warning
        and otherwise ignored.
        """
        import re
        _funcs = {}

        from kwimage import _internal

        try:
            import torch
        except ImportError:
            torch = None

        # These are pure python and should always be available
        from kwimage.algo._nms_backend import py_nms
        from kwimage.algo._nms_backend import torch_nms

        _funcs['numpy'] = py_nms.py_nms

        # Compare only the (major, minor) release numbers. This avoids
        # distutils.version.LooseVersion, which is unavailable on
        # Python >= 3.12. An unparsable version is treated as not recent.
        match = re.match(r'(\d+)\.(\d+)', np.__version__)
        recent_numpy = (
            match is not None and
            tuple(int(part) for part in match.groups()) >= (1, 20)
        )

        if torch is not None and recent_numpy:
            _funcs['torch'] = torch_nms.torch_nms

            if not _internal.KWIMAGE_DISABLE_TORCHVISION_NMS:
                # The torchvision _C library may cause segfaults, which is
                # why we have an option to disable even trying it
                try:
                    # TODO: torchvision impl might be the best, need to test
                    from torchvision import _C as C  # NOQA
                    import torchvision
                    _funcs['torchvision'] = torchvision.ops.nms
                except (ImportError, UnicodeDecodeError) as ex:
                    warnings.warn(
                        'optional torchvision C nms is not available: {}'.format(
                            str(ex)))

        if recent_numpy:
            # Only use cython extensions if numpy is > 1.20.
            # Otherwise there seems to be a
            # "Fatal Python error: Illegal instruction" that can occur
            # I dont know why this is yet
            # See https://gitlab.kitware.com/computer-vision/kwimage/-/jobs/6061311
            # for an example of when this last occurred
            try:
                if not _internal.KWIMAGE_DISABLE_C_EXTENSIONS:
                    from kwimage_ext.algo._nms_backend import cpu_nms
                    _funcs['cython_cpu'] = cpu_nms.cpu_nms
            except Exception as ex:
                warnings.warn(
                    'optional cpu_nms is not available: {}'.format(str(ex)))
            try:
                if not _internal.KWIMAGE_DISABLE_C_EXTENSIONS:
                    if torch is not None and torch.cuda.is_available():
                        from kwimage_ext.algo._nms_backend import gpu_nms
                        _funcs['cython_gpu'] = gpu_nms.gpu_nms
                        # NOTE: GPU is not the fastest on all systems.
                        # See the benchmarks for more info.
                        # ~/code/kwimage/dev/bench_nms.py
            except Exception as ex:
                warnings.warn(
                    'optional gpu_nms is not available: {}'.format(str(ex)))

        self._funcs = _funcs
        # Derive the valid names from this instance's own registry rather
        # than from the module-level singleton.
        self._valid = frozenset(_funcs)


_impls = _NMS_Impls()
def available_nms_impls():
    """
    List available values for the `impl` kwarg of `non_max_supression`

    The backend registry is populated on first use.

    Returns:
        List[str]: names of the registered implementations

    CommandLine:
        xdoctest -m kwimage.algo.algo_nms available_nms_impls

    Example:
        >>> impls = available_nms_impls()
        >>> assert 'numpy' in impls
        >>> print('impls = {!r}'.format(impls))
    """
    registry = _impls._funcs
    if not registry:
        # Nothing registered yet: discover the backends now
        _impls._lazy_init()
        registry = _impls._funcs
    return list(registry)
# @ub.memoize
def _heuristic_auto_nms_impl(code, num, valid=None):
    """
    Choose an NMS implementation based on the input type and number of boxes.

    Defined with help from ``~/code/kwimage/dev/bench_nms.py``

    Args:
        code (str): text that indicates which type of data you have
            tensor0 is a tensor on a cuda device, tensor is on the cpu, and
            ndarray is a numpy array.

        num (int): number of boxes you have to suppress.

        valid (Collection[str] | None): the valid implementations, an error
            will be raised if heuristic preferences do not intersect with
            this collection. If empty or None, all preferences are allowed.

    Returns:
        str: name of the most preferred implementation that is valid

    Raises:
        KeyError: if ``code`` is not one of the known data type codes
        Exception: if no preferred implementation is in ``valid``

    Ignore:
        python ~/code/kwimage/dev/bench_nms.py --show --small-boxes --thresh=0.6
    """
    if code not in {'tensor0', 'tensor', 'ndarray'}:
        raise KeyError(code)

    # The dict in each comment is the benchmark result the order came from.
    # Every list ends with a fallback chain that includes 'numpy', the one
    # implementation that is always registered.
    if num <= 10:
        if code == 'tensor0':
            # dict(cython_cpu=4118.4, torchvision=3042.5, cython_gpu=2244.4, torch=841.9)
            preference = ['cython_cpu', 'torchvision', 'cython_gpu', 'torch', 'numpy']
        elif code == 'tensor':
            # dict(torchvision=5857.1, cython_gpu=3058.1)
            preference = ['torchvision', 'cython_gpu', 'torch', 'numpy']
        else:
            # dict(cython_cpu=12226.1, numpy=7759.1, cython_gpu=3679.0, torch=1786.2)
            preference = ['cython_cpu', 'numpy', 'cython_gpu', 'torch']
    elif num <= 100:
        if code == 'tensor0':
            # dict(cython_cpu=4160.7, torchvision=3089.9, cython_gpu=2261.8, torch=846.8)
            preference = ['cython_cpu', 'torchvision', 'cython_gpu', 'torch', 'numpy']
        elif code == 'tensor':
            # dict(torchvision=5875.3, cython_gpu=3076.9)
            preference = ['torchvision', 'cython_gpu', 'torch', 'numpy']
        else:
            # dict(cython_cpu=12256.7, cython_gpu=3702.9, numpy=2311.3, torch=1738.0)
            preference = ['cython_cpu', 'cython_gpu', 'numpy', 'torch']
    elif num <= 200:
        if code == 'tensor0':
            # dict(cython_cpu=3460.8, torchvision=2912.9, cython_gpu=2125.2, torch=782.4)
            preference = ['cython_cpu', 'torchvision', 'cython_gpu', 'torch', 'numpy']
        elif code == 'tensor':
            # dict(torchvision=3394.6, cython_gpu=2641.2)
            preference = ['torchvision', 'cython_gpu', 'torch', 'numpy']
        else:
            # dict(cython_cpu=8220.6, cython_gpu=3114.5, torch=1240.7, numpy=309.5)
            preference = ['cython_cpu', 'cython_gpu', 'torch', 'numpy']
    elif num <= 300:
        if code == 'tensor0':
            # dict(torchvision=2647.1, cython_cpu=2264.9, cython_gpu=1915.5, torch=672.0)
            preference = ['torchvision', 'cython_cpu', 'cython_gpu', 'torch', 'numpy']
        elif code == 'tensor':
            # dict(cython_gpu=2496.9, torchvision=1781.1)
            preference = ['cython_gpu', 'torchvision', 'torch', 'numpy']
        else:
            # dict(cython_cpu=4085.6, cython_gpu=2944.4, torch=799.8, numpy=173.0)
            preference = ['cython_cpu', 'cython_gpu', 'torch', 'numpy']
    else:
        if code == 'tensor0':
            # dict(torchvision=2585.5, cython_gpu=1868.7, cython_cpu=1650.6, torch=623.1)
            preference = ['torchvision', 'cython_gpu', 'cython_cpu', 'torch', 'numpy']
        elif code == 'tensor':
            # dict(cython_gpu=2463.1, torchvision=1126.2)
            preference = ['cython_gpu', 'torchvision', 'torch', 'numpy']
        else:
            # dict(cython_gpu=2880.2, cython_cpu=2432.5, torch=511.9, numpy=114.0)
            preference = ['cython_gpu', 'cython_cpu', 'torch', 'numpy']

    if valid:
        # Order-preserving intersection of the preferences with ``valid``
        valid_pref = [name for name in preference if name in valid]
    else:
        valid_pref = preference

    if not valid_pref:
        raise Exception(
            'no valid nms algo: code={}, num={}, valid={}, preference={}, valid_pref={}'.format(
                code, num, valid, preference, valid_pref))

    impl = valid_pref[0]
    return impl
def non_max_supression(ltrb, scores, thresh, bias=0.0, classes=None,
                       impl='auto', device_id=None):
    """
    Non-Maximum Suppression - remove redundant bounding boxes

    Based on information from [CythonNMS]_ and [NMSPython]_.

    Args:
        ltrb (ndarray[Any, Float32]):
            Float32 array of shape Nx4 representing boxes in ltrb format

        scores (ndarray[Any, Float32]):
            Float32 array of shape N representing scores for each box

        thresh (float): iou threshold.
            Boxes are removed if they overlap greater than this threshold
            (i.e. Boxes are removed if iou > threshold).
            Thresh = 0 is the most strict, resulting in the fewest boxes, and 1
            is the most permissive resulting in the most.

        bias (float): bias for iou computation either 0 or 1

        classes (ndarray[Shape['*'], Int64] | None):
            integer classes. If specified NMS is done on a perclass basis.

        impl (str | List[str]): implementation can be "auto", "numpy",
            "cython_cpu", "cython_gpu", "torch", or "torchvision". The
            aliases "py", "cpu", and "gpu" are deprecated. If given as a
            list, it is a preference order and the first available
            implementation is used.

        device_id (int): used if impl is cython_gpu, device id to work on. If
            not specified `torch.cuda.current_device()` is used.

    Returns:
        indices of the boxes to keep. An empty list is returned for empty
        input and a list is returned when ``classes`` is given; otherwise
        the container type depends on the implementation.

    Raises:
        KeyError: if the requested implementation is unknown or unavailable

    Note:
        Using impl='cython_gpu' may result in an CUDA memory error that is not
        exposed to the python processes. In other words your program will hard
        crash if impl='cython_gpu', and you feed it too many bounding boxes.
        Ideally this will be fixed in the future.

    TODO:
        SoftNMS [SoftNMS]_.

    References:
        .. [CythonNMS] https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/cython_nms.pyx
        .. [NMSPython] https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
        .. [SoftNMS] https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx

    CommandLine:
        xdoctest -m ~/code/kwimage/kwimage/algo/algo_nms.py non_max_supression

    Example:
        >>> from kwimage.algo.algo_nms import *
        >>> from kwimage.algo.algo_nms import _impls
        >>> ltrb = np.array([
        >>>     [0, 0, 100, 100],
        >>>     [100, 100, 10, 10],
        >>>     [10, 10, 100, 100],
        >>>     [50, 50, 100, 100],
        >>> ], dtype=np.float32)
        >>> scores = np.array([.1, .5, .9, .1])
        >>> keep = non_max_supression(ltrb, scores, thresh=0.5, impl='numpy')
        >>> print('keep = {!r}'.format(keep))
        >>> assert keep == [2, 1, 3]
        >>> thresh = 0.0
        >>> non_max_supression(ltrb, scores, thresh, impl='numpy')
        >>> if 'numpy' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='numpy')
        >>>     assert list(keep) == [2, 1]
        >>> if 'cython_cpu' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='cython_cpu')
        >>>     assert list(keep) == [2, 1]
        >>> if 'cython_gpu' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='cython_gpu')
        >>>     assert list(keep) == [2, 1]
        >>> if 'torch' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='torch')
        >>>     assert set(keep.tolist()) == {2, 1}
        >>> if 'torchvision' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='torchvision')  # note torchvision has no bias
        >>>     assert list(keep) == [2]
        >>> thresh = 1.0
        >>> if 'numpy' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='numpy')
        >>>     assert list(keep) == [2, 1, 3, 0]
        >>> if 'cython_cpu' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='cython_cpu')
        >>>     assert list(keep) == [2, 1, 3, 0]
        >>> if 'cython_gpu' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='cython_gpu')
        >>>     assert list(keep) == [2, 1, 3, 0]
        >>> if 'torch' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='torch')
        >>>     assert set(keep.tolist()) == {2, 1, 3, 0}
        >>> if 'torchvision' in available_nms_impls():
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl='torchvision')  # note torchvision has no bias
        >>>     assert set(kwarray.ArrayAPI.tolist(keep)) == {2, 1, 3, 0}

    Example:
        >>> import ubelt as ub
        >>> ltrb = np.array([
        >>>     [0, 0, 100, 100],
        >>>     [100, 100, 10, 10],
        >>>     [10, 10, 100, 100],
        >>>     [50, 50, 100, 100],
        >>>     [100, 100, 150, 101],
        >>>     [120, 100, 180, 101],
        >>>     [150, 100, 200, 101],
        >>> ], dtype=np.float32)
        >>> scores = np.linspace(0, 1, len(ltrb))
        >>> thresh = .2
        >>> solutions = {}
        >>> if not _impls._funcs:
        >>>     _impls._lazy_init()
        >>> for impl in _impls._funcs:
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl=impl)
        >>>     solutions[impl] = sorted(keep)
        >>> assert 'numpy' in solutions
        >>> print('solutions = {}'.format(ub.urepr(solutions, nl=1)))
        >>> assert ub.allsame(solutions.values())

    CommandLine:
        xdoctest -m ~/code/kwimage/kwimage/algo/algo_nms.py non_max_supression

    Example:
        >>> import ubelt as ub
        >>> # Check that zero-area boxes are ok
        >>> ltrb = np.array([
        >>>     [0, 0, 0, 0],
        >>>     [0, 0, 0, 0],
        >>>     [10, 10, 10, 10],
        >>> ], dtype=np.float32)
        >>> scores = np.array([1, 2, 3], dtype=np.float32)
        >>> thresh = .2
        >>> solutions = {}
        >>> if not _impls._funcs:
        >>>     _impls._lazy_init()
        >>> for impl in _impls._funcs:
        >>>     keep = non_max_supression(ltrb, scores, thresh, impl=impl)
        >>>     solutions[impl] = sorted(keep)
        >>> assert 'numpy' in solutions
        >>> print('solutions = {}'.format(ub.urepr(solutions, nl=1)))
        >>> assert ub.allsame(solutions.values())
    """
    # Translate deprecated aliases to their current names
    deprecated_aliases = {
        'cpu': 'cython_cpu',
        'gpu': 'cython_gpu',
        'py': 'numpy',
    }
    if isinstance(impl, str) and impl in deprecated_aliases:
        new_impl = deprecated_aliases[impl]
        warnings.warn(
            'impl="{}" is deprecated use impl="{}" instead'.format(
                impl, new_impl),
            DeprecationWarning)
        impl = new_impl

    if not _impls._funcs:
        _impls._lazy_init()

    # Look up torch only AFTER the lazy init, which is what imports torch
    # when it is installed. Looking earlier could yield None on the first
    # call and break the 'torch' / 'cython_gpu' branches below.
    torch = sys.modules.get('torch', None)

    if ltrb.shape[0] == 0:
        return []

    if impl == 'auto':
        is_tensor = torch is not None and torch.is_tensor(ltrb)
        num = len(ltrb)
        if is_tensor:
            if ltrb.device.type == 'cuda':
                code = 'tensor0'
            else:
                code = 'tensor'
        else:
            code = 'ndarray'
        valid = _impls._valid
        impl = _heuristic_auto_nms_impl(code, num, valid)
    elif ub.iterable(impl):
        # if impl is iterable, it is a preference order
        for item in impl:
            if item in _impls._funcs:
                impl = item
                break
        else:
            raise KeyError('Unknown impls={}'.format(impl))

    if classes is not None:
        # Run NMS independently for each class and merge the kept indices
        keep = []
        for idxs in ub.group_items(range(len(classes)), classes).values():
            cls_ltrb = ltrb[idxs]
            cls_scores = scores[idxs]
            cls_keep = non_max_supression(cls_ltrb, cls_scores, thresh=thresh,
                                          bias=bias, impl=impl,
                                          device_id=device_id)
            # Map class-local indices back to indices into the input
            keep.extend(list(ub.take(idxs, cls_keep)))
        return keep
    else:
        if impl == 'numpy':
            api = kwarray.ArrayAPI.coerce(ltrb)
            ltrb = api.numpy(ltrb)
            scores = api.numpy(scores)
            func = _impls._funcs['numpy']
            keep = func(ltrb, scores, thresh, bias=float(bias))
        elif impl == 'torch' or impl == 'torchvision':
            api = kwarray.ArrayAPI.coerce(ltrb)
            ltrb = api.tensor(ltrb).float()
            scores = api.tensor(scores).float()
            if impl == 'torchvision':
                # if bias != 1:
                #     warnings.warn('torchvision only supports bias==1')
                func = _impls._funcs['torchvision']
                # Torchvision returns indices
                keep = func(ltrb, scores, iou_threshold=thresh)
            else:
                func = _impls._funcs['torch']
                # Default output of torch impl is a mask
                flags = func(ltrb, scores, thresh=thresh, bias=float(bias))
                keep = torch.nonzero(flags).view(-1)
            # Convert the kept indices with the array API of the input
            keep = api.numpy(keep)
        else:
            # TODO: it would be nice to be able to pass torch tensors here
            nms = _impls._funcs[impl]
            ltrb = kwarray.ArrayAPI.numpy(ltrb)
            scores = kwarray.ArrayAPI.numpy(scores)
            ltrb = ltrb.astype(np.float32)
            scores = scores.astype(np.float32)
            if impl == 'cython_gpu':
                # TODO: if the data is already on a torch GPU can we just
                # use it?
                # HACK: we should parameterize which device is used
                if device_id is None:
                    device_id = torch.cuda.current_device()
                keep = nms(ltrb, scores, float(thresh), bias=float(bias),
                           device_id=device_id)
            elif impl == 'cython_cpu':
                keep = nms(ltrb, scores, float(thresh), bias=float(bias))
            else:
                raise KeyError(impl)
    return keep
# TODO: add soft nms bindings


if __name__ == '__main__':
    # Run this module's doctests.
    #
    # CommandLine:
    #     xdoctest -m kwimage.algo.algo_nms
    from xdoctest import doctest_module
    doctest_module(__file__)