Skip to content

Commit

Permalink
add support for torch 2.0, 2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Mar 31, 2024
1 parent 8b6a240 commit ba8cc5c
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 33 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,16 @@ jobs:
fail-fast: false
matrix:
py_ver: [ "3.8", "3.9", "3.10", "3.11" ]
torch_ver: ["2.2.*"]
torchvision_ver: ["0.17.*"] # <0.17 incompatible with torch2.2
ort_ver: ["1.15.*", "1.16.*", "1.17.*"]
ort_ver: ["1.15.*", "1.16.*", "1.17.*"]
ort_ext_ver: ["0.8.*", "0.9.*", "0.10.*"]
onnx_ver: ["1.14.*", "1.15.*"]
include:
- torch_ver: "2.2.*"
torchvision_ver: "0.17.*"
- torch_ver: "2.1.*"
torchvision_ver: "0.16.*"
- torch_ver: "2.0.*"
torchvision_ver: "0.15.*"
exclude:
- py_ver: "3.11"
ort_ext_ver: "0.8.*"
Expand Down
2 changes: 1 addition & 1 deletion sony_custom_layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
# for use by setup.py and for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__
requirements = {
'tf': ['tensorflow>=2.10,<2.16'],
'torch': ['torch>=2.2.0', 'torchvision>=0.17.0'],
'torch': ['torch>=2.0', 'torchvision>=0.15'],
'torch_ort': ['onnxruntime', 'onnxruntime_extensions>=0.8.0'],
}
4 changes: 2 additions & 2 deletions sony_custom_layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
# -----------------------------------------------------------------------------

from sony_custom_layers.util.import_util import check_pip_requirements
from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements

check_pip_requirements(requirements['tf'])
validate_pip_requirements(requirements['tf'])

from .object_detection import FasterRCNNBoxDecode, SSDPostProcess, ScoreConverter # noqa: E402
from .custom_objects import custom_layers_scope # noqa: E402
6 changes: 3 additions & 3 deletions sony_custom_layers/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# -----------------------------------------------------------------------------
from typing import Optional, TYPE_CHECKING

from sony_custom_layers.util.import_util import check_pip_requirements
from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements

if TYPE_CHECKING:
import onnxruntime as ort

__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']

check_pip_requirements(requirements['torch'])
validate_pip_requirements(requirements['torch'])

from .object_detection import multiclass_nms, NMSResults # noqa: E402

Expand Down Expand Up @@ -53,7 +53,7 @@ def load_custom_ops(load_ort: bool = False,
SessionOptions object if ort registration was requested, otherwise None
"""
if load_ort or ort_session_ops:
check_pip_requirements(requirements['torch_ort'])
validate_pip_requirements(requirements['torch_ort'])

# trigger onnxruntime op registration
from .object_detection import nms_ort
Expand Down
50 changes: 31 additions & 19 deletions sony_custom_layers/pytorch/object_detection/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from torch import Tensor
import torchvision # noqa: F401 # needed for torch.ops.torchvision

MULTICLASS_NMS_TORCH_OP = 'sony::multiclass_nms'
from sony_custom_layers.util.import_util import is_compatible

CUSTOM_LIB_NAME = 'sony'
MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'
MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP

__all__ = ['multiclass_nms', 'NMSResults']

Expand Down Expand Up @@ -57,13 +61,19 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))


torch.library.define(
MULTICLASS_NMS_TORCH_OP,
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) -> "
"(Tensor, Tensor, Tensor, Tensor)")
custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")
op_name = custom_lib.define(schema)

if is_compatible('torch>=2.2'):
register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default')
else:
register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP)


@torch.library.impl(MULTICLASS_NMS_TORCH_OP, 'default')
@register_impl
def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers the torch op as torch.ops.sony.multiclass_nms """
Expand All @@ -74,19 +84,21 @@ def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshol
max_detections=max_detections)


@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Needed for torch.export """
ctx = torch.library.get_ctx()
batch = ctx.new_dynamic_size()
return NMSResults(
torch.empty((batch, max_detections, 4)),
torch.empty((batch, max_detections)),
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable
if is_compatible('torch>=2.2'):

@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Needed for torch.export """
ctx = torch.library.get_ctx()
batch = ctx.new_dynamic_size()
return NMSResults(
torch.empty((batch, max_detections, 4)),
torch.empty((batch, max_detections)),
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable


def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
Expand Down
4 changes: 2 additions & 2 deletions sony_custom_layers/pytorch/object_detection/nms_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# -----------------------------------------------------------------------------
import torch

from .nms import MULTICLASS_NMS_TORCH_OP
from .nms import MULTICLASS_NMS_TORCH_OP_QUALNAME

MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS"

Expand All @@ -42,4 +42,4 @@ def multiclass_nms_onnx(g, boxes, scores, score_threshold, iou_threshold, max_de
return outputs


torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP, multiclass_nms_onnx, opset_version=1)
torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP_QUALNAME, multiclass_nms_onnx, opset_version=1)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from sony_custom_layers.pytorch.object_detection import nms
from sony_custom_layers.pytorch import load_custom_ops
from sony_custom_layers.util.import_util import is_compatible
from sony_custom_layers.util.test_util import exec_in_clean_process


Expand Down Expand Up @@ -261,6 +262,7 @@ def test_ort(self, dynamic_batch, tmpdir_factory):
"""
exec_in_clean_process(code, check=True)

@pytest.mark.skipif(not is_compatible('torch>=2.2'), reason='unsupported')
def test_pt2_export(self, tmpdir_factory):

def f(boxes, scores):
Expand Down
23 changes: 20 additions & 3 deletions sony_custom_layers/util/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import List
from typing import List, Union

from packaging.requirements import Requirement
from packaging.version import parse
Expand All @@ -24,9 +24,9 @@ class RequirementError(Exception):
pass


def check_pip_requirements(requirements: List[str]):
def validate_pip_requirements(requirements: List[str]):
"""
Check if all requirements are installed and meet the version specifications.
Validate that all requirements are installed and meet the version specifications.
Args:
requirements: a list of pip-style requirement strings
Expand All @@ -47,3 +47,20 @@ def check_pip_requirements(requirements: List[str]):
error += f'\nRequired {req_str}, installed version {installed_ver}'
if error:
raise RequirementError(error)


def is_compatible(requirements: Union[str, List]) -> bool:
"""
Non-raising requirement(s) check
Args:
requirements (str, List): requirement pip-style string or a list of requirement strings
Returns:
(bool) whether requirement(s) are satisfied
"""
requirements = [requirements] if isinstance(requirements, str) else requirements
try:
validate_pip_requirements(requirements)
except RequirementError:
return False
return True

0 comments on commit ba8cc5c

Please sign in to comment.