Skip to content

Commit

Permalink
6.2 Release (#1751)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpeiz committed Feb 3, 2023
1 parent 8ac4610 commit 22cd170
Show file tree
Hide file tree
Showing 112 changed files with 6,267 additions and 2,963 deletions.
6 changes: 3 additions & 3 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ test_py37_pytorch:
WHEEL_PATH: build/dist/*cp37*10_15*

test_py37_tf1:
<<: *test_macos_pkg_with_reqs
<<: *test_macos_pkg
tags:
- macos12
dependencies:
Expand All @@ -107,10 +107,9 @@ test_py37_tf1:
PYTHON: "3.7"
TEST_PACKAGE: coremltools.converters.mil.frontend.tensorflow
WHEEL_PATH: build/dist/*cp37*10_15*
REQUIREMENTS: reqs/test_tf1.pip

test_py37_tf2:
<<: *test_macos_pkg
<<: *test_macos_pkg_with_reqs
tags:
- macos12
dependencies:
Expand All @@ -119,6 +118,7 @@ test_py37_tf2:
PYTHON: "3.7"
TEST_PACKAGE: coremltools.converters.mil.frontend.tensorflow2
WHEEL_PATH: build/dist/*cp37*10_15*
REQUIREMENTS: reqs/test.pip

test_py37_mil:
<<: *test_macos_pkg
Expand Down
1 change: 1 addition & 0 deletions coremlpython/CoreMLPython.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace CoreML {
private:
MLModel *m_model = nil;
NSURL *compiledUrl = nil;
bool m_deleteCompiledModelOnExit;
public:
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
Expand Down
68 changes: 43 additions & 25 deletions coremlpython/CoreMLPython.mm
Original file line number Diff line number Diff line change
Expand Up @@ -33,42 +33,60 @@ bool usingMacOS13OrHigher() {
return (NSProtocolFromString(@"MLProgram") != nil);
}

bool isCompiledModelPath(const std::string& path) {
const std::string fileExtension = ".mlmodelc";

size_t start = path.length() - fileExtension.length();
if (path.back() == '/') {
start--;
}
const std::string match = path.substr(start, fileExtension.length());

return (match == fileExtension);
}

Model::~Model() {
@autoreleasepool {
NSFileManager *fileManager = [NSFileManager defaultManager];
if (compiledUrl != nil) {
if (compiledUrl != nil and m_deleteCompiledModelOnExit) {
[fileManager removeItemAtURL:compiledUrl error:NULL];
}
}
}

Model::Model(const std::string& urlStr, const std::string& computeUnits) {
@autoreleasepool {

// Compile the model
NSError *error = nil;
NSURL *specUrl = Utils::stringToNSURL(urlStr);

// Swallow output for the very verbose coremlcompiler
int stdoutBack = dup(STDOUT_FILENO);
int devnull = open("/dev/null", O_WRONLY);
dup2(devnull, STDOUT_FILENO);

// Compile the model
compiledUrl = [MLModel compileModelAtURL:specUrl error:&error];

// Close all the file descriptors and revert back to normal
dup2(stdoutBack, STDOUT_FILENO);
close(devnull);
close(stdoutBack);

// Translate into a type that pybind11 can bridge to Python
if (error != nil) {
std::stringstream errmsg;
errmsg << "Error compiling model: \"";
errmsg << error.localizedDescription.UTF8String;
errmsg << "\".";
throw std::runtime_error(errmsg.str());

if (! isCompiledModelPath(urlStr)) {
// Compile the model
NSURL *specUrl = Utils::stringToNSURL(urlStr);

// Swallow output for the very verbose coremlcompiler
int stdoutBack = dup(STDOUT_FILENO);
int devnull = open("/dev/null", O_WRONLY);
dup2(devnull, STDOUT_FILENO);

// Compile the model
compiledUrl = [MLModel compileModelAtURL:specUrl error:&error];
m_deleteCompiledModelOnExit = true;

// Close all the file descriptors and revert back to normal
dup2(stdoutBack, STDOUT_FILENO);
close(devnull);
close(stdoutBack);

// Translate into a type that pybind11 can bridge to Python
if (error != nil) {
std::stringstream errmsg;
errmsg << "Error compiling model: \"";
errmsg << error.localizedDescription.UTF8String;
errmsg << "\".";
throw std::runtime_error(errmsg.str());
}
} else {
m_deleteCompiledModelOnExit = false; // Don't delete user specified file
compiledUrl = Utils::stringToNSURL(urlStr);
}

// Set compute unit
Expand Down
4 changes: 0 additions & 4 deletions coremltools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@
# New versions for iOS 16.0
_SPECIFICATION_VERSION_IOS_16 = 7

# New versions for iOS 17.0
_SPECIFICATION_VERSION_IOS_17 = 8

class ComputeUnit(_Enum):
'''
The set of processing-unit configurations the model can use to make predictions.
Expand All @@ -79,7 +76,6 @@ class ComputeUnit(_Enum):
_SPECIFICATION_VERSION_IOS_14: "CoreML4",
_SPECIFICATION_VERSION_IOS_15: "CoreML5",
_SPECIFICATION_VERSION_IOS_16: "CoreML6",
_SPECIFICATION_VERSION_IOS_17: "CoreML7",
}

# Default specification version for each backend
Expand Down
2 changes: 1 addition & 1 deletion coremltools/_deps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __get_sklearn_version(version):

# ---------------------------------------------------------------------------------------
_HAS_TORCH = True
_TORCH_MAX_VERSION = "1.12.1"
_TORCH_MAX_VERSION = "1.13.1"
try:
import torch
_warn_if_above_max_supported_version("Torch", torch.__version__, _TORCH_MAX_VERSION)
Expand Down
2 changes: 1 addition & 1 deletion coremltools/converters/mil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
from .input_types import (ClassifierConfig, ColorLayout, EnumeratedShapes,
ImageType, InputType, RangeDim, Shape, TensorType)
from .frontend.tensorflow.tf_op_registry import register_tf_op
from .frontend.torch import register_torch_op
from .frontend.torch import register_torch_op
7 changes: 1 addition & 6 deletions coremltools/converters/mil/_deployment_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from coremltools import (_SPECIFICATION_VERSION_IOS_13,
_SPECIFICATION_VERSION_IOS_14,
_SPECIFICATION_VERSION_IOS_15,
_SPECIFICATION_VERSION_IOS_16,
_SPECIFICATION_VERSION_IOS_17)
_SPECIFICATION_VERSION_IOS_16)


class AvailableTarget(IntEnum):
Expand All @@ -18,7 +17,6 @@ class AvailableTarget(IntEnum):
iOS14 = _SPECIFICATION_VERSION_IOS_14
iOS15 = _SPECIFICATION_VERSION_IOS_15
iOS16 = _SPECIFICATION_VERSION_IOS_16
iOS17 = _SPECIFICATION_VERSION_IOS_17

# macOS versions (aliases of iOS versions)
macOS15 = _SPECIFICATION_VERSION_IOS_13
Expand All @@ -28,21 +26,18 @@ class AvailableTarget(IntEnum):
macOS11 = _SPECIFICATION_VERSION_IOS_14
macOS12 = _SPECIFICATION_VERSION_IOS_15
macOS13 = _SPECIFICATION_VERSION_IOS_16
macOS14 = _SPECIFICATION_VERSION_IOS_17

# watchOS versions (aliases of iOS versions)
watchOS6 = _SPECIFICATION_VERSION_IOS_13
watchOS7 = _SPECIFICATION_VERSION_IOS_14
watchOS8 = _SPECIFICATION_VERSION_IOS_15
watchOS9 = _SPECIFICATION_VERSION_IOS_16
watchOS10 = _SPECIFICATION_VERSION_IOS_17

# tvOS versions (aliases of iOS versions)
tvOS13 = _SPECIFICATION_VERSION_IOS_13
tvOS14 = _SPECIFICATION_VERSION_IOS_14
tvOS15 = _SPECIFICATION_VERSION_IOS_15
tvOS16 = _SPECIFICATION_VERSION_IOS_16
tvOS17 = _SPECIFICATION_VERSION_IOS_17

# customized __str__
def __str__(self):
Expand Down
11 changes: 8 additions & 3 deletions coremltools/converters/mil/backend/mil/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,14 @@ def create_file_value_tensor(file_name, offset, dim, data_type):

def types_to_proto_primitive(valuetype):
if valuetype not in builtin_to_proto_types:
additional_error_msg = ""
if valuetype in (types.complex64, types.complex128):
additional_error_msg = (
"(MIL doesn't support complex data as model's output, please extract real and "
"imaginary parts explicitly.) "
)
raise ValueError(
"Unknown type {} to map from SSA types to Proto types".format(
valuetype)
f"Unknown map from SSA type {valuetype} to Proto type. {additional_error_msg}"
)
return builtin_to_proto_types[valuetype]

Expand Down Expand Up @@ -302,7 +307,7 @@ def create_immediate_value(var):
return create_tensor_value(var.val)
elif types.is_list(var.sym_type):
if var.elem_type == types.str:
return create_list_scalarvalue(var.val, np.str)
return create_list_scalarvalue(var.val, str)
elif var.elem_type == types.int64:
return create_list_scalarvalue(var.val, np.int64)
else:
Expand Down
6 changes: 3 additions & 3 deletions coremltools/converters/mil/backend/mil/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def translate_generic_op(op, parameters, blob_writer, literal_params=[]):

attr_dict["name"] = create_scalar_value(op.name)
attr_dict["class_name"] = create_scalar_value(class_name)
attr_dict["input_order"] = create_list_scalarvalue(input_order, np.str)
attr_dict["parameters"] = create_list_scalarvalue(parameters, np.str)
attr_dict["weights"] = create_list_scalarvalue(weights, np.str)
attr_dict["input_order"] = create_list_scalarvalue(input_order, str)
attr_dict["parameters"] = create_list_scalarvalue(parameters, str)
attr_dict["weights"] = create_list_scalarvalue(weights, str)
attr_dict["description"] = create_scalar_value(description)

return pm.Operation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def prog(x):
assert np.all(add_op.y.val == np.array([1.0, 2.0, 3.0]).reshape([1, 3, 1, 1]))

@pytest.mark.parametrize(
"scale_type, bias_type", itertools.product([np.float, np.int32], [np.float, np.int32])
"scale_type, bias_type", itertools.product([np.float32, np.int32], [np.float32, np.int32])
)
def test_scale_bias_types(self, scale_type, bias_type):
"""
Expand Down
8 changes: 4 additions & 4 deletions coremltools/converters/mil/backend/nn/op_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@ def add_const(const_context, builder, name, val):
return
if not isinstance(val, (_np.ndarray, _np.generic)):
val = _np.array([val])
if val.dtype != _np.float:
if val.dtype != float:
# nn proto only supports float32 activation. (e.g., pred in cond op
# needs to be converted to float)
val = val.astype(_np.float)
val = val.astype(float)
rank = len(val.shape)
if rank == 0:
builder.add_load_constant_nd(
Expand Down Expand Up @@ -755,9 +755,9 @@ def _add_elementwise_binary(
builder.add_less_than(**params)
return

if op.x.val is not None:
if op.x.can_be_folded_to_const():
add_const(const_context, builder, op.x.name, op.x.val)
if op.y.val is not None:
if op.y.can_be_folded_to_const():
if mode == "pow":
_add_elementwise_unary(
const_context,
Expand Down
35 changes: 17 additions & 18 deletions coremltools/converters/mil/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import os as _os
import stat as _stat
import tempfile as _tempfile
import warnings as _warnings

Expand Down Expand Up @@ -52,10 +50,9 @@ def __call__(self, model, *args, **kwargs):
msg = (
"Please update the minimum_deployment_target to {!s},"
" since op {} is only available in opset {!s} or newer."

).format(max_opset_version, op.op_type, max_opset_version)
raise ValueError(msg)

if "inputs" in kwargs and kwargs["inputs"] is not None:
inputs = kwargs["inputs"]
if not isinstance(inputs, (list, tuple)):
Expand Down Expand Up @@ -208,11 +205,8 @@ def _mil_convert(

if convert_to == 'mlprogram':
# mil_convert_to_proto places weight files inside the weights_dir
weights_dir = _tempfile.mkdtemp()
kwargs['weights_dir'] = weights_dir

# To make sure everyone can read and write to this directory (on par with os.mkdir())
_os.chmod(weights_dir, _stat.S_IRWXU | _stat.S_IRWXG | _stat.S_IRWXO)
weights_dir = _tempfile.TemporaryDirectory()
kwargs["weights_dir"] = weights_dir.name

proto, mil_program = mil_convert_to_proto(
model,
Expand All @@ -225,23 +219,28 @@ def _mil_convert(
_reset_conversion_state()

if convert_to == 'milinternal':
return mil_program # mil program
return mil_program # mil program
elif convert_to == 'milpython':
return proto # internal mil data structure
return proto # internal mil data structure

elif convert_to == 'mlprogram':
package_path = _create_mlpackage(proto, weights_dir, kwargs.get("package_dir"))
return modelClass(package_path,
is_temp_package=not kwargs.get('package_dir'),
mil_program=mil_program,
skip_model_load=kwargs.get('skip_model_load', False),
compute_units=compute_units)
elif convert_to == "mlprogram":
package_path = _create_mlpackage(
proto, kwargs.get("weights_dir"), kwargs.get("package_dir")
)
return modelClass(
package_path,
is_temp_package=not kwargs.get("package_dir"),
mil_program=mil_program,
skip_model_load=kwargs.get("skip_model_load", False),
compute_units=compute_units,
)

return modelClass(proto,
mil_program=mil_program,
skip_model_load=kwargs.get('skip_model_load', False),
compute_units=compute_units)


def mil_convert_to_proto(
model,
convert_from,
Expand Down
6 changes: 4 additions & 2 deletions coremltools/converters/mil/frontend/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import List, Optional

from coremltools.converters.mil.input_types import InputType
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Var, types
Expand Down Expand Up @@ -208,10 +210,10 @@ def _does_block_contain_symbolic_shape(block):
return False


def get_output_names(outputs):
def get_output_names(outputs) -> Optional[List[str]]:
"""
:param: list[ct.TensorType/ct.ImageType]
:return: list[str]
:return: list[str] or None
"""
output_names = None
if outputs is not None:
Expand Down
4 changes: 4 additions & 0 deletions coremltools/converters/mil/frontend/tensorflow/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def __init__(self, tfssa, inputs=None, outputs=None, opset_version=None):
continue
if any([isinstance(s, RangeDim) for s in inputtype.shape.shape]):
continue
if inputtype.name not in graph:
raise ValueError(
f"The input {inputtype.name} provided is not in graph."
)
node = graph[inputtype.name]
shape = [-1 if is_symbolic(s) else s for s in inputtype.shape.shape]
node.attr["_output_shapes"] = [shape] # list of length 1
Expand Down
Loading

0 comments on commit 22cd170

Please sign in to comment.