Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for torch 2.0 and 2.1 #11

Merged
merged 4 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
# NOTE if newer versions are added, update sony_custom_layers.__init__ pinned_requirements!!!
py_ver: ["3.8", "3.9", "3.10", "3.11"]
tf_ver: ["2.10", "2.11", "2.12", "2.13", "2.14", "2.15"]
exclude:
Expand All @@ -38,7 +39,7 @@ jobs:
python-version: ${{matrix.py_ver}}
- name: Install dependencies
run: |
pip install tensorflow==${{matrix.tf_ver}}
pip install tensorflow==${{matrix.tf_ver}}.*
pip install -r requirements_test.txt
pip list
- name: Run pytest
Expand All @@ -50,15 +51,25 @@ jobs:
strategy:
fail-fast: false
matrix:
# NOTE if newer versions are added, update sony_custom_layers.__init__ pinned_requirements!!!
py_ver: [ "3.8", "3.9", "3.10", "3.11" ]
torch_ver: ["2.2.*"]
torchvision_ver: ["0.17.*"] # <0.17 incompatible with torch2.2
ort_ver: ["1.15.*", "1.16.*", "1.17.*"]
ort_ext_ver: ["0.8.*", "0.9.*", "0.10.*"]
onnx_ver: ["1.14.*", "1.15.*"]
torch_ver: ["2.0", "2.1", "2.2"]
ort_ver: ["1.15", "1.16", "1.17"]
ort_ext_ver: ["0.8", "0.9", "0.10"]
include:
- torch_ver: "2.2"
torchvision_ver: "0.17"
onnx_ver: "1.15"
- torch_ver: "2.1"
torchvision_ver: "0.16"
onnx_ver: "1.14"
- torch_ver: "2.0"
torchvision_ver: "0.15"
onnx_ver: "1.15"

exclude:
- py_ver: "3.11"
ort_ext_ver: "0.8.*"
ort_ext_ver: "0.8"
steps:
- name: Checkout
uses: actions/checkout@v4
Expand All @@ -68,11 +79,11 @@ jobs:
python-version: ${{matrix.py_ver}}
- name: Install dependencies
run: |
pip install torch==${{matrix.torch_ver}} \
torchvision==${{matrix.torchvision_ver}} \
onnxruntime==${{matrix.ort_ver}} \
onnxruntime_extensions==${{matrix.ort_ext_ver}} \
onnx==${{matrix.onnx_ver}} \
pip install torch==${{matrix.torch_ver}}.* \
torchvision==${{matrix.torchvision_ver}}.* \
onnxruntime==${{matrix.ort_ver}}.* \
onnxruntime_extensions==${{matrix.ort_ext_ver}}.* \
onnx==${{matrix.onnx_ver}}.* \
--index-url https://download.pytorch.org/whl/cpu \
--extra-index-url https://pypi.org/simple

Expand All @@ -91,12 +102,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Run pre-commit
run: |
./install-pre-commit.sh
pre-commit run --all

- name: get new dev tag
shell: bash
run : |
Expand Down Expand Up @@ -128,7 +137,7 @@ jobs:
echo "__version__ = '${{ env.new_ver }}'" > sony_custom_layers/version.py
echo "print sony_custom_layers/version.py"
cat sony_custom_layers/version.py

sed -i 's/name = sony-custom-layers/name = sony-custom-layers-dev/' setup.cfg
echo "print setup.cfg"
cat setup.cfg
Expand All @@ -148,5 +157,3 @@ jobs:
git tag ${{ env.new_tag }}
git push origin ${{ env.new_tag }}
fi


10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ To install the latest stable release of SCL, run the following command:
pip install sony-custom-layers
```
By default, no framework dependencies are installed.
To install SCL including the dependencies for TensorFlow:
To install SCL including the latest tested dependencies (up to patch version) for TensorFlow:
```
pip install sony-custom-layers[tf]
```
To install SCL including the dependencies for PyTorch/ONNX/OnnxRuntime:
To install SCL including the latest tested dependencies (up to patch version) for PyTorch/ONNX/OnnxRuntime:
```
pip install sony-custom-layers[torch]
```
Expand All @@ -43,9 +43,9 @@ pip install sony-custom-layers[torch]

#### PyTorch

| **Tested FW versions** | **Tested Python version** | **Serialization** |
|---------------------------------|---------------------------|------------------------------------------------------------------|
| torch 2.2<br/>torchvision 0.17<br/>onnxruntime 1.15-1.17<br/>onnxruntime_extensions 0.8-0.10<br/>onnx 1.14-1.15| 3.8-3.11 | .onnx (via torch.onnx.export)<br/>.pt2 (via torch.export.export) |
| **Tested FW versions** | **Tested Python version** | **Serialization** |
|--------------------------------------------------------------------------------------------------------------------------|---------------------------|---------------------------------------------------------------------------------|
| torch 2.0-2.2<br/>torchvision 0.15-0.17<br/>onnxruntime 1.15-1.17<br/>onnxruntime_extensions 0.8-0.10<br/>onnx 1.14-1.15 | 3.8-3.11 | .onnx (via torch.onnx.export)<br/>.pt2 (via torch.export.export, torch2.2 only) |

## Implemented Layers
SCL currently includes implementations of the following layers:
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
"""
from setuptools import setup

from sony_custom_layers import requirements
from sony_custom_layers import pinned_requirements

extras_require = {
'torch': requirements['torch'] + requirements['torch_ort'],
'tf': requirements['tf'],
'torch': pinned_requirements['torch'] + pinned_requirements['torch_ort'],
'tf': pinned_requirements['tf'],
}

setup(extras_require=extras_require)
15 changes: 11 additions & 4 deletions sony_custom_layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@
# limitations under the License.
# -----------------------------------------------------------------------------

# for use by setup.py and for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__
# minimal requirements for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__
requirements = {
'tf': ['tensorflow>=2.10,<2.16'],
'torch': ['torch>=2.2.0', 'torchvision>=0.17.0'],
'torch_ort': ['onnxruntime', 'onnxruntime_extensions>=0.8.0'],
'tf': ['tensorflow>=2.10'],
'torch': ['torch>=2.0', 'torchvision>=0.15'],
'torch_ort': ['onnx', 'onnxruntime', 'onnxruntime_extensions>=0.8.0'],
}

# pinned requirements of latest tested versions for extra_requires
pinned_requirements = {
'tf': ['tensorflow==2.15.*'],
'torch': ['torch==2.2.*', 'torchvision==0.17.*'],
'torch_ort': ['onnx==1.15.*', 'onnxruntime==1.17.*', 'onnxruntime_extensions==0.10.*']
}
4 changes: 2 additions & 2 deletions sony_custom_layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
# -----------------------------------------------------------------------------

from sony_custom_layers.util.import_util import check_pip_requirements
from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements

check_pip_requirements(requirements['tf'])
validate_pip_requirements(requirements['tf'])

from .object_detection import FasterRCNNBoxDecode, SSDPostProcess, ScoreConverter # noqa: E402
from .custom_objects import custom_layers_scope # noqa: E402
6 changes: 3 additions & 3 deletions sony_custom_layers/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# -----------------------------------------------------------------------------
from typing import Optional, TYPE_CHECKING

from sony_custom_layers.util.import_util import check_pip_requirements
from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements

if TYPE_CHECKING:
import onnxruntime as ort

__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']

check_pip_requirements(requirements['torch'])
validate_pip_requirements(requirements['torch'])

from .object_detection import multiclass_nms, NMSResults # noqa: E402

Expand Down Expand Up @@ -53,7 +53,7 @@ def load_custom_ops(load_ort: bool = False,
SessionOptions object if ort registration was requested, otherwise None
"""
if load_ort or ort_session_ops:
check_pip_requirements(requirements['torch_ort'])
validate_pip_requirements(requirements['torch_ort'])

# trigger onnxruntime op registration
from .object_detection import nms_ort
Expand Down
50 changes: 31 additions & 19 deletions sony_custom_layers/pytorch/object_detection/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from torch import Tensor
import torchvision # noqa: F401 # needed for torch.ops.torchvision

MULTICLASS_NMS_TORCH_OP = 'sony::multiclass_nms'
from sony_custom_layers.util.import_util import is_compatible

CUSTOM_LIB_NAME = 'sony'
MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'
MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP

__all__ = ['multiclass_nms', 'NMSResults']

Expand Down Expand Up @@ -57,13 +61,19 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))


torch.library.define(
MULTICLASS_NMS_TORCH_OP,
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) -> "
"(Tensor, Tensor, Tensor, Tensor)")
custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")
op_name = custom_lib.define(schema)

if is_compatible('torch>=2.2'):
register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default')
else:
register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP)


@torch.library.impl(MULTICLASS_NMS_TORCH_OP, 'default')
@register_impl
def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers the torch op as torch.ops.sony.multiclass_nms """
Expand All @@ -74,19 +84,21 @@ def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshol
max_detections=max_detections)


@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Needed for torch.export """
ctx = torch.library.get_ctx()
batch = ctx.new_dynamic_size()
return NMSResults(
torch.empty((batch, max_detections, 4)),
torch.empty((batch, max_detections)),
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable
if is_compatible('torch>=2.2'):

@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Needed for torch.export """
ctx = torch.library.get_ctx()
batch = ctx.new_dynamic_size()
return NMSResults(
torch.empty((batch, max_detections, 4)),
torch.empty((batch, max_detections)),
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable


def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
Expand Down
4 changes: 2 additions & 2 deletions sony_custom_layers/pytorch/object_detection/nms_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# -----------------------------------------------------------------------------
import torch

from .nms import MULTICLASS_NMS_TORCH_OP
from .nms import MULTICLASS_NMS_TORCH_OP_QUALNAME

MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS"

Expand All @@ -42,4 +42,4 @@ def multiclass_nms_onnx(g, boxes, scores, score_threshold, iou_threshold, max_de
return outputs


torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP, multiclass_nms_onnx, opset_version=1)
torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP_QUALNAME, multiclass_nms_onnx, opset_version=1)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from sony_custom_layers.pytorch.object_detection import nms
from sony_custom_layers.pytorch import load_custom_ops
from sony_custom_layers.util.import_util import is_compatible
from sony_custom_layers.util.test_util import exec_in_clean_process


Expand Down Expand Up @@ -261,6 +262,7 @@ def test_ort(self, dynamic_batch, tmpdir_factory):
"""
exec_in_clean_process(code, check=True)

@pytest.mark.skipif(not is_compatible('torch>=2.2'), reason='unsupported')
def test_pt2_export(self, tmpdir_factory):

def f(boxes, scores):
Expand Down
23 changes: 20 additions & 3 deletions sony_custom_layers/util/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import List
from typing import List, Union

from packaging.requirements import Requirement
from packaging.version import parse
Expand All @@ -24,9 +24,9 @@ class RequirementError(Exception):
pass


def check_pip_requirements(requirements: List[str]):
def validate_pip_requirements(requirements: List[str]):
"""
Check if all requirements are installed and meet the version specifications.
Validate that all requirements are installed and meet the version specifications.

Args:
requirements: a list of pip-style requirement strings
Expand All @@ -47,3 +47,20 @@ def check_pip_requirements(requirements: List[str]):
error += f'\nRequired {req_str}, installed version {installed_ver}'
if error:
raise RequirementError(error)


def is_compatible(requirements: Union[str, List]) -> bool:
"""
Non-raising requirement(s) check
Args:
requirements (str, List): requirement pip-style string or a list of requirement strings

Returns:
(bool) whether requirement(s) are satisfied
"""
requirements = [requirements] if isinstance(requirements, str) else requirements
try:
validate_pip_requirements(requirements)
except RequirementError:
return False
return True