diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b799f20..fa73e4c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -51,11 +51,16 @@ jobs: fail-fast: false matrix: py_ver: [ "3.8", "3.9", "3.10", "3.11" ] - torch_ver: ["2.2.*"] - torchvision_ver: ["0.17.*"] # <0.17 incompatible with torch2.2 - ort_ver: ["1.15.*", "1.16.*", "1.17.*"] + ort_ver: ["1.15.*", "1.16.*", "1.17.*"] ort_ext_ver: ["0.8.*", "0.9.*", "0.10.*"] onnx_ver: ["1.14.*", "1.15.*"] + include: + - torch_ver: "2.2.*" + torchvision_ver: "0.17.*" + - torch_ver: "2.1.*" + torchvision_ver: "0.16.*" + - torch_ver: "2.0.*" + torchvision_ver: "0.15.*" exclude: - py_ver: "3.11" ort_ext_ver: "0.8.*" diff --git a/sony_custom_layers/__init__.py b/sony_custom_layers/__init__.py index abd89cb..e8a2814 100644 --- a/sony_custom_layers/__init__.py +++ b/sony_custom_layers/__init__.py @@ -17,6 +17,6 @@ # for use by setup.py and for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__ requirements = { 'tf': ['tensorflow>=2.10,<2.16'], - 'torch': ['torch>=2.2.0', 'torchvision>=0.17.0'], + 'torch': ['torch>=2.0', 'torchvision>=0.15'], 'torch_ort': ['onnxruntime', 'onnxruntime_extensions>=0.8.0'], } diff --git a/sony_custom_layers/keras/__init__.py b/sony_custom_layers/keras/__init__.py index 4d7d25e..6ad534e 100644 --- a/sony_custom_layers/keras/__init__.py +++ b/sony_custom_layers/keras/__init__.py @@ -14,10 +14,10 @@ # limitations under the License. # ----------------------------------------------------------------------------- -from sony_custom_layers.util.import_util import check_pip_requirements +from sony_custom_layers.util.import_util import validate_pip_requirements from sony_custom_layers import requirements -check_pip_requirements(requirements['tf']) +validate_pip_requirements(requirements['tf']) from .object_detection import FasterRCNNBoxDecode, SSDPostProcess, ScoreConverter # noqa: E402 from .custom_objects import custom_layers_scope # noqa: E402 diff --git a/sony_custom_layers/pytorch/__init__.py b/sony_custom_layers/pytorch/__init__.py index e29a600..5e6e8db 100644 --- a/sony_custom_layers/pytorch/__init__.py +++ b/sony_custom_layers/pytorch/__init__.py @@ -15,7 +15,7 @@ # ----------------------------------------------------------------------------- from typing import Optional, TYPE_CHECKING -from sony_custom_layers.util.import_util import check_pip_requirements +from sony_custom_layers.util.import_util import validate_pip_requirements from sony_custom_layers import requirements if TYPE_CHECKING: @@ -23,7 +23,7 @@ __all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops'] -check_pip_requirements(requirements['torch']) +validate_pip_requirements(requirements['torch']) from .object_detection import multiclass_nms, NMSResults # noqa: E402 @@ -53,7 +53,7 @@ def load_custom_ops(load_ort: bool = False, SessionOptions object if ort registration was requested, otherwise None """ if load_ort or ort_session_ops: - check_pip_requirements(requirements['torch_ort']) + validate_pip_requirements(requirements['torch_ort']) # trigger onnxruntime op registration from .object_detection import nms_ort diff --git a/sony_custom_layers/pytorch/object_detection/nms.py b/sony_custom_layers/pytorch/object_detection/nms.py index 5dd0981..46073aa 100644 --- a/sony_custom_layers/pytorch/object_detection/nms.py +++ b/sony_custom_layers/pytorch/object_detection/nms.py @@ -20,7 +20,11 @@ from torch import Tensor import torchvision # noqa: F401 # needed for torch.ops.torchvision -MULTICLASS_NMS_TORCH_OP = 'sony::multiclass_nms' +from sony_custom_layers.util.import_util import is_compatible + +CUSTOM_LIB_NAME = 'sony' +MULTICLASS_NMS_TORCH_OP = 'multiclass_nms' +MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP __all__ = ['multiclass_nms', 'NMSResults'] @@ -57,13 +61,19 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections)) -torch.library.define( - MULTICLASS_NMS_TORCH_OP, - "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) -> " - "(Tensor, Tensor, Tensor, Tensor)") +custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF") +schema = (MULTICLASS_NMS_TORCH_OP + + "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) " + "-> (Tensor, Tensor, Tensor, Tensor)") +op_name = custom_lib.define(schema) + +if is_compatible('torch>=2.2'): + register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default') +else: + register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP) -@torch.library.impl(MULTICLASS_NMS_TORCH_OP, 'default') +@register_impl def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults: """ Registers the torch op as torch.ops.sony.multiclass_nms """ @@ -74,19 +84,21 @@ def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshol max_detections=max_detections) -@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP) -def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float, - max_detections: int) -> NMSResults: - """ Registers torch op's abstract implementation. It specifies the properties of the output tensors. - Needed for torch.export """ - ctx = torch.library.get_ctx() - batch = ctx.new_dynamic_size() - return NMSResults( - torch.empty((batch, max_detections, 4)), - torch.empty((batch, max_detections)), - torch.empty((batch, max_detections), dtype=torch.int64), - torch.empty((batch, 1), dtype=torch.int64) - ) # yapf: disable +if is_compatible('torch>=2.2'): + + @torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME) + def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float, + max_detections: int) -> NMSResults: + """ Registers torch op's abstract implementation. It specifies the properties of the output tensors. + Needed for torch.export """ + ctx = torch.library.get_ctx() + batch = ctx.new_dynamic_size() + return NMSResults( + torch.empty((batch, max_detections, 4)), + torch.empty((batch, max_detections)), + torch.empty((batch, max_detections), dtype=torch.int64), + torch.empty((batch, 1), dtype=torch.int64) + ) # yapf: disable def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float, diff --git a/sony_custom_layers/pytorch/object_detection/nms_onnx.py b/sony_custom_layers/pytorch/object_detection/nms_onnx.py index d514a77..83e70a8 100644 --- a/sony_custom_layers/pytorch/object_detection/nms_onnx.py +++ b/sony_custom_layers/pytorch/object_detection/nms_onnx.py @@ -15,7 +15,7 @@ # ----------------------------------------------------------------------------- import torch -from .nms import MULTICLASS_NMS_TORCH_OP +from .nms import MULTICLASS_NMS_TORCH_OP_QUALNAME MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS" @@ -42,4 +42,4 @@ def multiclass_nms_onnx(g, boxes, scores, score_threshold, iou_threshold, max_de return outputs -torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP, multiclass_nms_onnx, opset_version=1) +torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP_QUALNAME, multiclass_nms_onnx, opset_version=1) diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py b/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py index 5df1f76..d88e4d2 100644 --- a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py +++ b/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py @@ -25,6 +25,7 @@ from sony_custom_layers.pytorch.object_detection import nms from sony_custom_layers.pytorch import load_custom_ops +from sony_custom_layers.util.import_util import is_compatible from sony_custom_layers.util.test_util import exec_in_clean_process @@ -261,6 +262,7 @@ def test_ort(self, dynamic_batch, tmpdir_factory): """ exec_in_clean_process(code, check=True) + @pytest.mark.skipif(not is_compatible('torch>=2.2'), reason='unsupported') def test_pt2_export(self, tmpdir_factory): def f(boxes, scores): diff --git a/sony_custom_layers/util/import_util.py b/sony_custom_layers/util/import_util.py index b6bf03b..3ad9d5e 100644 --- a/sony_custom_layers/util/import_util.py +++ b/sony_custom_layers/util/import_util.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Union from packaging.requirements import Requirement from packaging.version import parse @@ -24,9 +24,9 @@ class RequirementError(Exception): pass -def check_pip_requirements(requirements: List[str]): +def validate_pip_requirements(requirements: List[str]): """ - Check if all requirements are installed and meet the version specifications. + Validate that all requirements are installed and meet the version specifications. Args: requirements: a list of pip-style requirement strings @@ -47,3 +47,20 @@ def check_pip_requirements(requirements: List[str]): error += f'\nRequired {req_str}, installed version {installed_ver}' if error: raise RequirementError(error) + + +def is_compatible(requirements: Union[str, List]) -> bool: + """ + Non-raising requirement(s) check + Args: + requirements (str, List): requirement pip-style string or a list of requirement strings + + Returns: + (bool) whether requirement(s) are satisfied + """ + requirements = [requirements] if isinstance(requirements, str) else requirements + try: + validate_pip_requirements(requirements) + except RequirementError: + return False + return True