diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index f267dc755135..25d7890faeba 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -13,6 +13,8 @@ using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { +// Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to +// filter suppported ones. static std::set Float16Ops = { "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", @@ -110,11 +112,13 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBu return true; } +// only support MLProgram for FP16 #if defined(COREML_ENABLE_MLPROGRAM) if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) { return true; } #endif + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index a2cbef6dd57d..153ae841b238 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -32,7 +32,7 @@ class BaseOpBuilder : public IOpBuilder { : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { } - // currently we only support float + // currently we support float/float16 static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index aa3060d62686..e8a138aa4979 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -27,7 +27,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; - // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary std::string_view coreml_op_type; if (op_type == "Sqrt") { coreml_op_type = "sqrt";