Skip to content

Commit

Permalink
Merge pull request #27 from sony/indices
Browse files Browse the repository at this point in the history
[torch] add multiclass_nms_with_indices layer
  • Loading branch information
irenaby committed Sep 10, 2024
2 parents 8f26840 + e3510cf commit 49d5483
Show file tree
Hide file tree
Showing 14 changed files with 1,310 additions and 541 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ jobs:
python-version: ${{matrix.py_ver}}
- name: Install dependencies
run: |
pip install tensorflow==${{matrix.tf_ver}}.*
if [ ${{matrix.tf_ver}} == 2.10 ] || [ ${{matrix.tf_ver}} == 2.11 ];then
extra_req='numpy<2'
fi
pip install tensorflow==${{matrix.tf_ver}}.* $extra_req
pip install -r requirements_test.txt
pip list
- name: Run pytest
Expand Down
6 changes: 3 additions & 3 deletions docs/index.html

Large diffs are not rendered by default.

24 changes: 15 additions & 9 deletions docs/sony_custom_layers/keras.html

Large diffs are not rendered by default.

754 changes: 552 additions & 202 deletions docs/sony_custom_layers/pytorch.html

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion sony_custom_layers/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
if TYPE_CHECKING:
import onnxruntime as ort

__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']
__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops']

validate_installed_libraries(required_libraries['torch'])

from .object_detection import multiclass_nms, NMSResults # noqa: E402
from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402


def load_custom_ops(load_ort: bool = False,
Expand Down
53 changes: 53 additions & 0 deletions sony_custom_layers/pytorch/custom_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -----------------------------------------------------------------------------
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import Callable

import torch

from sony_custom_layers.util.import_util import is_compatible

CUSTOM_LIB_NAME = 'sony'
custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")


def get_op_qualname(torch_op_name):
""" Op qualified name """
return CUSTOM_LIB_NAME + '::' + torch_op_name


def register_op(torch_op_name: str, schema: str, impl: Callable):
"""
Register torch custom op under the custom library.
Args:
torch_op_name: op name to register.
schema: schema for the custom op.
impl: implementation of the custom op.
Returns:
Custom op qualified name.
"""
torch_op_qualname = get_op_qualname(torch_op_name)

custom_lib.define(schema)

if is_compatible('torch>=2.2'):
register_impl = torch.library.impl(torch_op_qualname, 'default')
else:
register_impl = torch.library.impl(custom_lib, torch_op_name)
register_impl(impl)

return torch_op_qualname
9 changes: 8 additions & 1 deletion sony_custom_layers/pytorch/object_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
# -----------------------------------------------------------------------------

from .nms import multiclass_nms, NMSResults
from .nms_with_indices import multiclass_nms_with_indices, NMSWithIndicesResults

# trigger onnx op registration
from . import nms_onnx

__all__ = ['multiclass_nms', 'NMSResults']
__all__ = [
'multiclass_nms',
'multiclass_nms_with_indices',
'NMSResults',
'NMSWithIndicesResults',
]
168 changes: 33 additions & 135 deletions sony_custom_layers/pytorch/object_detection/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import Tuple, NamedTuple, Union, Callable
from typing import NamedTuple, Callable

import numpy as np
import torch
from torch import Tensor
import torchvision # noqa: F401 # needed for torch.ops.torchvision

from sony_custom_layers.pytorch.custom_lib import register_op
from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS
from sony_custom_layers.util.import_util import is_compatible

CUSTOM_LIB_NAME = 'sony'
MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'
MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP

__all__ = ['multiclass_nms', 'NMSResults']

Expand All @@ -36,17 +35,19 @@ class NMSResults(NamedTuple):
labels: Tensor
n_valid: Tensor

# Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding
# new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses.
def detach(self) -> 'NMSResults':
""" Detach all tensors and return a new NMSResults object """
""" Detach all tensors and return a new object """
return self.apply(lambda t: t.detach())

def cpu(self) -> 'NMSResults':
""" Move all tensors to cpu and return a new NMSResults object """
""" Move all tensors to cpu and return a new object """
return self.apply(lambda t: t.cpu())

def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
""" Apply any function to all tensors and return a NMSResults new object """
return NMSResults(*[f(t) for t in self])
""" Apply any function to all tensors and return a new object """
return self.__class__(*[f(t) for t in self])


def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
Expand All @@ -56,6 +57,8 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
If you also require the input indices of the selected boxes, see `multiclass_nms_with_indices`.
Args:
boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
(x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
Expand Down Expand Up @@ -92,32 +95,35 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))


custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")
op_name = custom_lib.define(schema)
######################
# Register custom op #
######################

if is_compatible('torch>=2.2'):
register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default')
else:
register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP)

def _multiclass_nms_impl(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" This implementation is intended only to be registered as custom torch and onnxruntime op.
NamedTuple is used for clarity, it is not preserved when run through torch / onnxruntime op. """
res, valid_dets = _batch_multiclass_nms(boxes,
scores,
score_threshold=score_threshold,
iou_threshold=iou_threshold,
max_detections=max_detections)
return NMSResults(boxes=res[..., :4],
scores=res[..., SCORES],
labels=res[..., LABELS].to(torch.int64),
n_valid=valid_dets.to(torch.int64))

@register_impl
def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers the torch op as torch.ops.sony.multiclass_nms """
return _multiclass_nms_impl(boxes,
scores,
score_threshold=score_threshold,
iou_threshold=iou_threshold,
max_detections=max_detections)

schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")

op_qualname = register_op(MULTICLASS_NMS_TORCH_OP, schema, _multiclass_nms_impl)

if is_compatible('torch>=2.2'):

@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME)
@torch.library.impl_abstract(op_qualname)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Expand All @@ -130,111 +136,3 @@ def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_thresh
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable


def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
iou_threshold: float, max_detections: int) -> NMSResults:
""" See multiclass_nms """
# this is needed for onnxruntime implementation
if not isinstance(boxes, Tensor):
boxes = Tensor(boxes)
if not isinstance(scores, Tensor):
scores = Tensor(scores)

if not 0 <= score_threshold <= 1:
raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]')
if not 0 <= iou_threshold <= 1:
raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]')
if max_detections <= 0:
raise ValueError(f'Invalid non-positive max_detections {max_detections}')

if boxes.ndim != 3 or boxes.shape[-1] != 4:
raise ValueError(f'Invalid input boxes shape {boxes.shape}. Expected shape (batch, n_boxes, 4).')
if scores.ndim != 3:
raise ValueError(f'Invalid input scores shape {scores.shape}. Expected shape (batch, n_boxes, n_classes).')
if boxes.shape[-2] != scores.shape[-2]:
raise ValueError(f'Mismatch in the number of boxes between input boxes ({boxes.shape[-2]}) '
f'and scores ({scores.shape[-2]})')

batch = boxes.shape[0]
res = torch.zeros((batch, max_detections, 6), device=boxes.device)
valid_dets = torch.zeros((batch, 1), device=boxes.device)
for i in range(batch):
res[i], valid_dets[i] = _image_multiclass_nms(boxes[i],
scores[i],
score_threshold=score_threshold,
iou_threshold=iou_threshold,
max_detections=max_detections)

return NMSResults(boxes=res[..., :4],
scores=res[..., 4],
labels=res[..., 5].to(torch.int64),
n_valid=valid_dets.to(torch.int64))


def _image_multiclass_nms(boxes: Tensor, scores: Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> Tuple[Tensor, int]:
"""
Performs multi-class non-maximum suppression on a single image
Args:
boxes: input boxes of shape [n_boxes, 4]
scores: input scores of shape [n_boxes, n_classes]
score_threshold: score threshold
iou_threshold: intersection over union threshold
max_detections: fixed number of detections to return
Returns:
A tensor of shape [max_detections, 6] and the number of valid detections.
out[:, :4] contains the selected boxes
out[:, 4] and out[:, 5] contain the scores and labels for the selected boxes
"""
x = _convert_inputs(boxes, scores, score_threshold)
out = torch.zeros(max_detections, 6, device=boxes.device)
if x.size(0) == 0:
return out, 0
idxs = _nms_with_class_offsets(x, iou_threshold=iou_threshold)
idxs = idxs[:max_detections]
valid_dets = idxs.numel()
out[:valid_dets] = x[idxs]
return out, valid_dets


def _convert_inputs(boxes: Tensor, scores: Tensor, score_threshold: float) -> Tensor:
"""
Converts inputs and filters out boxes with score below the threshold.
Args:
boxes: input boxes of shape [n_boxes, 4]
scores: input scores of shape [n_boxes, n_classes]
score_threshold: score threshold for nms candidates
Returns:
A tensor of shape [m, 6] containing m nms candidates above the score threshold.
x[:, :4] contains the boxes with replication for different labels
x[:, 4] contains the scores
x[:, 5] contains the labels indices (label i corresponds to input scores[:, i])
"""
n_boxes, n_classes = scores.shape
scores_mask = scores > score_threshold
box_indices = torch.arange(n_boxes, device=boxes.device).unsqueeze(1).expand(-1, n_classes)[scores_mask]
x = torch.empty((box_indices.numel(), 6), device=boxes.device)
x[:, :4] = boxes[box_indices]
x[:, 4] = scores[scores_mask]
x[:, 5] = torch.arange(n_classes, device=boxes.device).unsqueeze(0).expand(n_boxes, -1)[scores_mask]
return x


def _nms_with_class_offsets(x: Tensor, iou_threshold: float) -> Tensor:
"""
Args:
x: nms candidates of shape [n, 6] ([:,:4] boxes, [:, 4] scores, [:, 5] labels)
iou_threshold: intersection over union threshold
Returns:
Indices of the selected candidates
"""
# shift boxes of each class to prevent intersection between boxes of different classes, and use single-class nms
# (similar to torchvision batched_nms trick)
offsets = x[:, 5:] * (x[:, :4].max() + 1)
shifted_boxes = x[:, :4] + offsets
return torch.ops.torchvision.nms(shifted_boxes, x[:, 4], iou_threshold)
Loading

0 comments on commit 49d5483

Please sign in to comment.