Skip to content

Commit

Permalink
5.1 Release (#1334)
Browse files Browse the repository at this point in the history
* 5.1 Release

* Minor unit test fix
  • Loading branch information
TobyRoseman committed Nov 9, 2021
1 parent 22a8877 commit acadd11
Show file tree
Hide file tree
Showing 182 changed files with 3,823 additions and 1,824 deletions.
38 changes: 22 additions & 16 deletions coremltools/_deps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,28 @@
"""
from distutils.version import StrictVersion as _StrictVersion
import logging as _logging
from packaging import version
import platform as _platform
import re as _re
import sys as _sys
from packaging import version


def __get_version(version):
def _get_version(version):
# matching 1.6.1, and 1.6.1rc, 1.6.1.dev
version_regex = r"^\d+\.\d+\.\d+"
version = _re.search(version_regex, str(version)).group(0)
return _StrictVersion(version)


def _warn_if_above_max_supported_version(package_name, package_version, max_supported_version):
if _get_version(package_version) > _StrictVersion(max_supported_version):
_logging.warning(
"%s version %s has not been tested with coremltools. You may run into unexpected errors. "
"%s %s is the most recent version that has been tested."
% (package_name, package_version, package_name, max_supported_version)
)


# ---------------------------------------------------------------------------------------

_IS_MACOS = _sys.platform == "darwin"
Expand Down Expand Up @@ -77,8 +86,10 @@ def __get_sklearn_version(version):

# ---------------------------------------------------------------------------------------
_HAS_XGBOOST = True
_XGBOOST_MAX_VERSION = "1.4.2"
try:
import xgboost
_warn_if_above_max_supported_version("XGBoost", xgboost.__version__, _XGBOOST_MAX_VERSION)
except:
_HAS_XGBOOST = False

Expand All @@ -89,12 +100,12 @@ def __get_sklearn_version(version):
_TF_1_MIN_VERSION = "1.12.0"
_TF_1_MAX_VERSION = "1.15.0"
_TF_2_MIN_VERSION = "2.1.0"
_TF_2_MAX_VERSION = "2.3.1"
_TF_2_MAX_VERSION = "2.5.0"

try:
import tensorflow

tf_ver = __get_version(tensorflow.__version__)
tf_ver = _get_version(tensorflow.__version__)

# TensorFlow
if tf_ver < _StrictVersion("2.0.0"):
Expand All @@ -112,11 +123,7 @@ def __get_sklearn_version(version):
)
% (tensorflow.__version__, _TF_1_MIN_VERSION)
)
elif tf_ver > _StrictVersion(_TF_1_MAX_VERSION):
_logging.warning(
"TensorFlow version %s detected. Last version known to be fully compatible is %s ."
% (tensorflow.__version__, _TF_1_MAX_VERSION)
)
_warn_if_above_max_supported_version("TensorFlow", tensorflow.__version__, _TF_1_MAX_VERSION)
elif _HAS_TF_2:
if tf_ver < _StrictVersion(_TF_2_MIN_VERSION):
_logging.warn(
Expand All @@ -126,11 +133,7 @@ def __get_sklearn_version(version):
)
% (tensorflow.__version__, _TF_2_MIN_VERSION)
)
elif tf_ver > _StrictVersion(_TF_2_MAX_VERSION):
_logging.warning(
"TensorFlow version %s detected. Last version known to be fully compatible is %s ."
% (tensorflow.__version__, _TF_2_MAX_VERSION)
)
_warn_if_above_max_supported_version("TensorFlow", tensorflow.__version__, _TF_2_MAX_VERSION)

except:
_HAS_TF = False
Expand Down Expand Up @@ -168,7 +171,7 @@ def __get_sklearn_version(version):
sys.stderr = stderr
import tensorflow

k_ver = __get_version(keras.__version__)
k_ver = _get_version(keras.__version__)

# keras 1 version too old
if k_ver < _StrictVersion(_KERAS_MIN_VERSION):
Expand All @@ -186,7 +189,8 @@ def __get_sklearn_version(version):
_HAS_KERAS_TF = False
_logging.warning(
(
"Keras version %s detected. Last version known to be fully compatible of Keras is %s ."
"Keras version %s has not been tested with coremltools. You may run into unexpected errors. "
"Keras %s is the most recent version that has been tested."
)
% (keras.__version__, _KERAS_MAX_VERSION)
)
Expand Down Expand Up @@ -214,8 +218,10 @@ def __get_sklearn_version(version):

# ---------------------------------------------------------------------------------------
_HAS_TORCH = True
_TORCH_MAX_VERSION = "1.9.1"
try:
import torch
_warn_if_above_max_supported_version("Torch", torch.__version__, _TORCH_MAX_VERSION)
except:
_HAS_TORCH = False
MSG_TORCH_NOT_FOUND = "PyTorch not found."
Expand Down
36 changes: 26 additions & 10 deletions coremltools/converters/_converters_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def convert(
compute_precision=None,
skip_model_load=False,
compute_units=_ComputeUnit.ALL,
**kwargs
useCPUOnly=False,
package_dir=None,
debug=False,
):
"""
Convert a TensorFlow or PyTorch model to the Core ML model format as either
Expand Down Expand Up @@ -230,6 +232,24 @@ def convert(
- ``coremltools.ComputeUnit.CPU_AND_GPU``: Use both the CPU and GPU, but not the
neural engine.
useCPUOnly: bool
Deprecated, to be removed in coremltools 6.0. Please use `compute_units` instead.
- if True, identical to setting compute_units to `coremltools.ComputeUnit.CPU_ONLY``
- if False, identical to setting compute_units to `coremltools.ComputeUnit.ALL``
package_dir : str
Post conversion, the model is compiled to form the MLModel object ready for prediction.
This requires a temporary directory to hold the mlmodelc archive.
- if not None, must be a path to a directory that is used for
temporarily storing the compiled model assets. If None, a temporary directory is created.
debug : bool
This flag should generally be False except for debugging purposes
Setting this flag to True:
- For Torch conversion, it will print the list of supported and unsupported ops
found in the model if conversion fails due to an unsupported op.
- For Tensorflow conversion, it will cause to display extra logging and visualizations
Returns
-------
model : ``coremltools.models.MLModel`` or ``coremltools.converters.mil.Program``
Expand Down Expand Up @@ -284,9 +304,9 @@ def convert(
exact_source = _determine_source(model, source, outputs)
exact_target = _determine_target(convert_to, minimum_deployment_target)
_validate_inputs(model, exact_source, inputs, outputs, classifier_config, compute_precision,
exact_target, **kwargs)
exact_target)

if "useCPUOnly" in kwargs and kwargs["useCPUOnly"]:
if useCPUOnly:
warnings.warn('The "useCPUOnly" parameter is deprecated and will be removed in 6.0. '
'Use the compute_units parameter: "compute_units=coremotools.ComputeUnits.CPU_ONLY".')
compute_units = _ComputeUnit.CPU_ONLY
Expand All @@ -313,7 +333,8 @@ def convert(
transforms=tuple(transforms),
skip_model_load=skip_model_load,
compute_units=compute_units,
**kwargs
package_dir=package_dir,
debug=debug,
)

if exact_target == 'milinternal':
Expand Down Expand Up @@ -344,8 +365,7 @@ def _check_deployment_target(minimum_deployment_target):
)
raise TypeError(msg.format(minimum_deployment_target))

def _validate_inputs(model, exact_source, inputs, outputs, classifier_config, compute_precision, convert_to,
**kwargs):
def _validate_inputs(model, exact_source, inputs, outputs, classifier_config, compute_precision, convert_to):
"""
Validate and process model, inputs, outputs, classifier_config based on
`exact_source` (which cannot be `auto`)
Expand Down Expand Up @@ -399,10 +419,6 @@ def raise_if_duplicated(input_list):
raise ValueError("Input should be a list of TensorType or ImageType")

elif exact_source == "pytorch":
if "example_inputs" in kwargs:
msg = 'Unexpected argument "example_inputs" found'
raise ValueError(msg)

if inputs is None:
msg = 'Expected argument for pytorch "inputs" not provided'
raise ValueError(msg)
Expand Down
5 changes: 5 additions & 0 deletions coremltools/converters/_profile_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (c) 2021, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import os
import time

Expand Down
53 changes: 52 additions & 1 deletion coremltools/converters/mil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,58 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from .mil import *
# This import should be pruned rdar://84519338
from .mil import (
block,
Block,
BoolInputType,
BoolTensorInputType,
builder,
Builder,
curr_block,
DefaultInputs,
FloatInputType,
FloatTensorInputType,
Function,
get_existing_symbol,
get_new_symbol,
get_new_variadic_symbol,
input_type,
InputSpec,
IntInputType,
IntOrFloatInputType,
IntOrFloatOrBoolInputType,
IntTensorInputType,
InternalInputType,
InternalScalarOrTensorInputType,
InternalStringInputType,
InternalVar,
ListInputType,
ListOrScalarOrTensorInputType,
ListVar,
mil_list,
operation,
Operation,
ops,
Placeholder,
precondition,
program,
Program,
PyFunctionInputType,
register_op,
SPACES,
SUPPORT_FLOAT_TYPES,
SUPPORT_INT_TYPES,
ScalarOrTensorInputType,
StringInputType,
Symbol,
TensorInputType,
TupleInputType,
types,
var,
Var,
visitors
)

from .frontend.torch import register_torch_op

Expand Down
5 changes: 4 additions & 1 deletion coremltools/converters/mil/backend/mil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) 2020, Apple Inc. All rights reserved.
# Copyright (c) 2020, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from .load import load
5 changes: 5 additions & 0 deletions coremltools/converters/mil/backend/mil/helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (c) 2021, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import numpy as np
import os
import re
Expand Down
59 changes: 35 additions & 24 deletions coremltools/converters/mil/backend/mil/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,56 @@
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import logging
import numpy as _np
import numpy as np
import os
import tempfile
import shutil

from coremltools.converters.mil.backend.mil.helper import *
from coremltools.converters.mil.backend.backend_helper import _get_probability_var_for_classifier
from .passes import mil_passes
import coremltools.proto.MIL_pb2 as pm
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil import Function
from coremltools import _SPECIFICATION_VERSION_IOS_15
from coremltools.converters.mil.backend.mil.helper import (
cast_to_framework_io_dtype,
create_file_value,
create_immediate_value,
create_list_scalarvalue,
create_scalar_value,
types_to_proto
)
from coremltools.converters.mil.backend.backend_helper import _get_probability_var_for_classifier
from coremltools.converters.mil.mil import (
Builder as mb,
Function,
mil_list,
types
)
from coremltools.converters.mil.backend.nn.load import _set_optional_inputs
from coremltools.converters.mil.input_types import ImageType, TensorType, EnumeratedShapes, RangeDim
from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry
from coremltools.converters.mil.mil.types.symbolic import (
any_symbolic,
any_variadic,
is_symbolic,
)
from coremltools.converters.mil.mil.types.type_mapping import types_int64
from coremltools.libmilstoragepython import _BlobStorageWriter as BlobWriter
from coremltools.models.model import _WEIGHTS_FILE_NAME
from coremltools.models.neural_network.flexible_shape_utils import (
NeuralNetworkImageSize,
NeuralNetworkImageSizeRange,
add_enumerated_image_sizes,
add_multiarray_ndshape_enumeration,
NeuralNetworkImageSize,
NeuralNetworkImageSizeRange,
set_multiarray_ndshape_range,
update_image_size_range,
update_image_size_range
)
from coremltools.proto import (
FeatureTypes_pb2 as ft,
MIL_pb2 as pm,
Model_pb2 as ml
)

from coremltools.libmilstoragepython import _BlobStorageWriter as BlobWriter

import coremltools.proto.Model_pb2 as ml
import coremltools.proto.FeatureTypes_pb2 as ft
from coremltools.converters.mil.input_types import ImageType, TensorType, EnumeratedShapes, RangeDim
from coremltools.models.model import _WEIGHTS_FILE_NAME
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import mil_list
from coremltools import _SPECIFICATION_VERSION_IOS_15

def should_use_weight_file(val):
return (
val is not None
and isinstance(val, (_np.ndarray, _np.generic))
and isinstance(val, (np.ndarray, np.generic))
and val.size >= 10
and val.dtype in ['float16', 'float32']
)
Expand Down Expand Up @@ -97,7 +106,7 @@ def translate_generic_op(op, parameters, blob_writer, literal_params=[]):
blocks = None
if len(op.blocks) > 0:
blocks = [create_block(b, parameters, blob_writer) \
for b in op.blocks]
for b in op.blocks]

op_type = op.op_type
attr_dict = {}
Expand Down Expand Up @@ -206,6 +215,9 @@ def _add_classify_op(prog, classifier_config):

# add the classify op now
with block:
# cast the int label to np.int64
if isinstance(classes[0], int):
classes = [np.int64(x) for x in classes]
classes_var = mb.const(val=mil_list(classes))
out = mb.classify(probabilities=probability_var, classes=classes_var)

Expand Down Expand Up @@ -344,7 +356,7 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs):
keytype, valtype = var.sym_type.T
if types.is_str(keytype):
output_feature_type.dictionaryType.stringKeyType.MergeFromString(b"")
elif (keytype == types_int64):
elif (keytype == types.int64):
output_feature_type.dictionaryType.int64KeyType.MergeFromString(b"")
else:
raise ValueError("Dictionary key type not supported.")
Expand Down Expand Up @@ -445,7 +457,6 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs):
model, input_name, lower_bounds=lb, upper_bounds=ub
)


# Set optional inputs
_set_optional_inputs(model, input_types)

Expand Down
Loading

0 comments on commit acadd11

Please sign in to comment.