From 1d6f13fb92204df344301ccc12a5f292a0cc44ed Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Thu, 1 Feb 2024 13:08:26 +0800 Subject: [PATCH] [VitisAI] Refactor the VAIEP to use MSFT's standalone API (#19058) ### Description Refactor the VAIEP to use MSFT's standalone API ### Motivation and Context Vitis ONNX RT VAI should switch to using the standalone API for ONNX EPs in order to decouple the EP from onnxruntime.dll and the providers.dll. This will help to simplify customer deployment of applications and use cases that need to share their onnxruntime.dll with other applications. --------- Co-authored-by: Zhenze Wang Co-authored-by: zz002 --- cmake/onnxruntime.cmake | 1 - cmake/onnxruntime_providers_vitisai.cmake | 32 +- cmake/onnxruntime_python.cmake | 11 +- cmake/onnxruntime_unittests.cmake | 1 - .../core/session/onnxruntime_c_api.h | 17 + .../core/session/onnxruntime_cxx_api.h | 3 + .../core/session/onnxruntime_cxx_inline.h | 19 + .../providers/provider_factory_creators.h | 4 - .../providers/shared_library/provider_api.h | 9 +- .../provider_bridge_provider.cc | 4 + .../shared_library/provider_interfaces.h | 87 +++++ .../shared_library/provider_wrappedtypes.h | 117 +++++- .../core/providers/vitisai/imp/attr_proto.cc | 120 +++--- .../core/providers/vitisai/imp/attr_proto.h | 46 +-- .../core/providers/vitisai/imp/capability.cc | 73 ++-- .../core/providers/vitisai/imp/global_api.cc | 367 +++++++----------- .../core/providers/vitisai/imp/graph.cc | 127 +++--- .../core/providers/vitisai/imp/node.cc | 11 +- .../core/providers/vitisai/imp/node_arg.cc | 155 ++------ .../core/providers/vitisai/imp/node_attrs.cc | 114 ------ .../providers/vitisai/imp/register_xir_ops.cc | 117 +----- .../providers/vitisai/imp/tensor_proto.cc | 100 ++--- .../core/providers/vitisai/imp/tensor_proto.h | 41 +- .../vitisai/include/vaip/capability.h | 3 +- .../vitisai/include/vaip/global_api.h | 13 +- .../providers/vitisai/include/vaip/graph.h | 22 +- .../providers/vitisai/include/vaip/my_ort.h | 26 +- .../providers/vitisai/include/vaip/node.h | 8 - .../providers/vitisai/include/vaip/node_arg.h | 7 +- .../vitisai/include/vaip/node_attrs.h | 46 --- .../vitisai/include/vaip/vaip_ort_api.h | 8 +- .../core/providers/vitisai/symbols.def | 2 + .../core/providers/vitisai/version_script.lds | 9 + .../vitisai/vitisai_execution_provider.cc | 71 +--- .../vitisai/vitisai_execution_provider.h | 4 +- .../vitisai/vitisai_provider_factory.cc | 38 +- onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 4 + .../core/session/provider_bridge_ort.cc | 242 +++++++++++- .../core/session/provider_registration.cc | 16 +- .../python/onnxruntime_pybind_state.cc | 2 +- setup.py | 3 + 42 files changed, 1000 insertions(+), 1101 deletions(-) delete mode 100644 onnxruntime/core/providers/vitisai/imp/node_attrs.cc delete mode 100644 onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h create mode 100644 onnxruntime/core/providers/vitisai/symbols.def create mode 100644 onnxruntime/core/providers/vitisai/version_script.lds diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index c900f4d4b09a..2ead13e55419 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -189,7 +189,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} - ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 0951c2d02664..183a3e196af4 100644 
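For context, here is a minimal sketch (not part of the patch) of how an application could opt into the VitisAI EP through the `AppendExecutionProvider_VitisAI` entry point this change adds to the C/C++ API. The option key `config_file` and the file paths are illustrative placeholders only, not values defined by this PR:

```cpp
#include <onnxruntime_cxx_api.h>  // exposes the new AppendExecutionProvider_VitisAI wrapper
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "vitisai-example"};
  Ort::SessionOptions session_options;

  // Provider options are passed as string key/value pairs; "config_file" is a
  // hypothetical key used only to illustrate the call shape.
  std::unordered_map<std::string, std::string> vitisai_options{
      {"config_file", "/path/to/vaip_config.json"}};
  session_options.AppendExecutionProvider_VitisAI(vitisai_options);

  // With the standalone-API refactor, the EP is resolved through the provider
  // bridge at session creation time, so the application links only against
  // onnxruntime rather than the provider libraries.
  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};
  return 0;
}
```

This mirrors the pattern already used by other execution providers (e.g. OpenVINO) in `onnxruntime_cxx_inline.h`, where the option map is flattened into key/value arrays before calling the corresponding `OrtApi` function.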
--- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -14,14 +14,19 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) - if(NOT MSVC) - target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) - endif(NOT MSVC) + onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} nlohmann_json::nlohmann_json safeint_interface flatbuffers::flatbuffers) + target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED}) + if(MSVC) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/vitisai/symbols.def") + else(MSVC) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/vitisai/version_script.lds -Xlinker --gc-sections") + endif(MSVC) target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) if(MSVC) @@ -30,17 +35,18 @@ target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") # for unused formal parameter target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") + # for type name first seen using 'class' now seen using 'struct' + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4099") else(MSVC) + target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) endif(MSVC) set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_vitisai - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + install(TARGETS onnxruntime_providers_vitisai + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 2e3594f256f6..456344aa34d9 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -170,7 +170,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_TVM} - 
${PROVIDERS_VITISAI} ${PROVIDERS_NNAPI} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} @@ -852,6 +851,16 @@ if (onnxruntime_USE_DNNL) ) endif() +if (onnxruntime_USE_VITISAI) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${DNNL_DLL_PATH} $ + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_TENSORRT) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 714f35380ca0..6a4551ad94d9 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -591,7 +591,6 @@ set(ONNXRUNTIME_TEST_LIBS # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} ${PROVIDERS_JS} - ${PROVIDERS_VITISAI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2ce9d361e8e5..5577c840c537 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4569,6 +4569,23 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 7a553f9f9400..ae4c4bef90c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -901,6 +901,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 957e849cf5d4..23246adff254 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -885,6 +885,25 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Ope return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& 
entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs) { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 42a58097e163..6a4ab6a3d211 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,10 +78,6 @@ #include "core/providers/tvm/tvm_provider_factory_creator.h" #endif -#if defined(USE_VITISAI) -#include "core/providers/vitisai/vitisai_provider_factory_creator.h" -#endif - #if defined(USE_XNNPACK) #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 1e3a528d8772..b78279040acb 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -95,12 +95,15 @@ enum OperatorStatus : int { }; // onnx Protobuf types (All of these are direct mappings to the onnx types except for the Repeated*Field ones which map to a Repeated*Field type) -struct int64s; // RepeatedField +struct int64s; // RepeatedField +struct float32s; // RepeatedField struct AttributeProto; struct GraphProto; struct ModelProto; struct NodeProto; struct SparseTensorProto; +struct StringStringEntryProto; +struct StringStringEntryProtos; // RepeatedPtrField struct TensorProto; struct TensorProtos; // RepeatedPtrField struct TensorShapeProto_Dimension; @@ -113,6 +116,9 @@ struct TypeProto_Sequence; struct TypeProto; struct ValueInfoProto; struct ValueInfoProtos; // RepeatedPtrField +struct InferenceContext; +class GraphInferencer; +using InferenceFunction = std::function; } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -249,6 +255,7 @@ constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; +constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 6dbe103791e4..da17135878fe 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -492,6 +492,10 @@ template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, 
size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } +Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) { + return g_host->UnpackInitializerData(tensor, model_path, unpacked_tensor); +} } // namespace utils diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a216b2bfc6d0..f5a832744386 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -91,6 +91,7 @@ using HashValue = uint64_t; using NodeIndex = size_t; // We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types) // using NodeAttributes = std::unordered_map; +using ModelMetaData = std::unordered_map; using InitializedTensorSet = std::unordered_map; @@ -201,6 +202,8 @@ struct ProviderHost { virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) = 0; + virtual Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) = 0; virtual uint16_t math__floatToHalf(float f) = 0; virtual float math__halfToFloat(uint16_t h) = 0; @@ -261,12 +264,32 @@ struct ProviderHost { virtual void logging__Capture__operator_delete(logging::Capture* p) noexcept = 0; virtual std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept = 0; + // Env + virtual Env& Env__Default() = 0; + // Utils::DataTypeUtils virtual const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) = 0; // int64s virtual int int64s__size(const ONNX_NAMESPACE::int64s* p) = 0; virtual const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) = 0; + virtual void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) = 0; + virtual const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) = 0; + + // float32s + virtual void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) = 0; + virtual const float* float32s__data(const ONNX_NAMESPACE::float32s* p) = 0; + virtual int float32s__size(const ONNX_NAMESPACE::float32s* p) = 0; + + // StringStringEntryProto + virtual std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; + virtual std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; + + // StringStringEntryProtos + virtual void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0; #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional @@ -283,6 +306,7 @@ 
struct ProviderHost { virtual const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; + virtual void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) = 0; #if !defined(DISABLE_SPARSE_TENSORS) // TypeProto_SparseTensor @@ -327,9 +351,17 @@ struct ProviderHost { virtual float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; virtual const ::std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; virtual const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t size) = 0; + virtual void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float size) = 0; + virtual void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& size) = 0; virtual int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; + virtual void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) = 0; virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0; virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; @@ -352,6 +384,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0; + virtual std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0; // ModelProto @@ -367,6 +400,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) = 0; virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0; // NodeProto virtual std::unique_ptr NodeProto__construct() = 0; @@ -381,19 +415,33 @@ struct ProviderHost { virtual void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) = 0; virtual void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0; virtual bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void 
TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) = 0; + virtual const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) = 0; virtual bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) = 0; virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) = 0; virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) = 0; virtual bool TensorProto_DataType_IsValid(int value) = 0; // TensorProtos virtual ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) = 0; // TensorShapeProto_Dimension virtual int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; @@ -403,6 +451,8 @@ struct ProviderHost { virtual bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; virtual bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; virtual void TensorShapeProto_Dimension__clear_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; + virtual const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const = 0; + virtual void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) = 0; // TensorShapeProto_Dimensions virtual std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; @@ -426,6 +476,8 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; + // ConfigOptions virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; @@ -651,6 +703,7 @@ struct ProviderHost { virtual void Node__ToProto(const Node* p, 
ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) = 0; virtual const NodeAttributes& Node__GetAttributes(const Node* p) noexcept = 0; + virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) = 0; virtual size_t Node__GetInputEdgesCount(const Node* p) noexcept = 0; virtual size_t Node__GetOutputEdgesCount(const Node* p) noexcept = 0; @@ -660,10 +713,13 @@ struct ProviderHost { virtual std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept = 0; virtual void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) = 0; + virtual int Node__NodeType(const Node* p) const noexcept = 0; virtual const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) = 0; virtual std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const = 0; @@ -674,6 +730,7 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept = 0; virtual bool NodeArg__Exists(const NodeArg* p) const noexcept = 0; virtual const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept = 0; + virtual Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) = 0; // NodeAttributes virtual std::unique_ptr NodeAttributes__construct() = 0; @@ -691,12 +748,18 @@ struct ProviderHost { virtual std::unique_ptr NodeAttributes__find(const NodeAttributes* p, const std::string& key) = 0; virtual void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) = 0; virtual void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; + virtual void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0; // Model + virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, + const PathString& model_path, const logging::Logger& logger) = 0; virtual void Model__operator_delete(Model* p) = 0; virtual Graph& Model__MainGraph(Model* p) = 0; virtual std::unique_ptr Model__ToProto(Model* p) = 0; + virtual std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) = 0; + virtual const ModelMetaData& Model__MetaData(const Model* p) const noexcept = 0; + virtual Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) = 0; // Graph virtual std::unique_ptr Graph__CreateGraphViewer(const Graph* p) = 0; @@ -714,6 +777,7 @@ struct ProviderHost { virtual void Graph__SetOutputs(Graph* p, gsl::span outputs) = 0; virtual const std::vector& Graph__GetInputs(const Graph* p) noexcept = 0; + virtual std::vector Graph__Nodes(const Graph* p) = 0; virtual bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) = 0; virtual const Node* 
Graph__ParentNode(const Graph* p) const = 0; @@ -723,6 +787,26 @@ struct ProviderHost { virtual const Path& Graph__ModelPath(const Graph* p) const = 0; virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0; virtual bool Graph__IsSubgraph(const Graph* p) = 0; + virtual const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const = 0; + virtual const Model& Graph__GetModel(const Graph* p) = 0; + virtual void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const = 0; + virtual Graph& Graph__SetGraphResolveNeeded(Graph* p) = 0; + virtual void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) = 0; + + virtual std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const = 0; + virtual void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) = 0; + virtual void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) = 0; + virtual void Graph__RemoveNode(Graph* p, NodeIndex index) = 0; + virtual Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) = 0; + virtual void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) = 0; + virtual const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const = 0; + virtual const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) = 0; virtual int Graph__MaxNodeIndex(const Graph* p) const noexcept = 0; virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0; virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0; @@ -757,11 +841,14 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; + virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; // Path virtual PathString Path__ToPathString(const Path* p) noexcept = 0; virtual const std::vector& Path__GetComponents(const Path* p) noexcept = 0; virtual bool Path__IsEmpty(const Path* p) noexcept = 0; + virtual std::unique_ptr Path__construct() = 0; + virtual void Path__operator_delete(ONNX_NAMESPACE::Path* p) = 0; // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index f46c76fd3421..dde4005c80b9 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -52,11 +52,34 @@ namespace ONNX_NAMESPACE { struct int64s final { int size() const { return g_host->int64s__size(this); } const int64_t& Get(int index) const { return g_host->int64s__Get(this, index); } + const int64_t* data() const { return g_host->int64s__data(this); } const int64_t& operator[](int index) const { return Get(index); } - + void Reserve(int size) { g_host->int64s__Reserve(this, size); } 
PROVIDER_DISALLOW_ALL(int64s) }; +struct float32s final { + void Reserve(int size) { g_host->float32s__Reserve(this, size); } + const float* data() const { return g_host->float32s__data(this); } + int size() const { return g_host->float32s__size(this); } + PROVIDER_DISALLOW_ALL(float32s) +}; + +struct StringStringEntryProto final { + std::string* mutable_key() { return g_host->StringStringEntryProto__mutable_key(this); } + std::string* mutable_value() { return g_host->StringStringEntryProto__mutable_value(this); } + + PROVIDER_DISALLOW_ALL(StringStringEntryProto) +}; + +struct StringStringEntryProtos final { + void Clear() { g_host->StringStringEntryProtos__Clear(this); } + StringStringEntryProto* Add() { return g_host->StringStringEntryProtos__Add(this); } + int size() { return g_host->StringStringEntryProtos__size(this); } + StringStringEntryProto& at(int index) { return g_host->StringStringEntryProtos__at(this, index); } + + PROVIDER_DISALLOW_ALL(StringStringEntryProtos) +}; struct AttributeProto final { static std::unique_ptr Create() { return g_host->AttributeProto__construct(); } void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); } @@ -71,9 +94,18 @@ struct AttributeProto final { float floats(int i) const { return g_host->AttributeProto__floats(this, i); } const std::string& strings(int i) const { return g_host->AttributeProto__strings(this, i); } const int64s& ints() const { return g_host->AttributeProto__ints(this); } + const float32s& floats() const { return g_host->AttributeProto__floats(this); } + int64s* mutable_ints() { return g_host->AttributeProto__mutable_ints(this); } + float32s* mutable_floats() { return g_host->AttributeProto__mutable_floats(this); } + void add_ints(int64_t value) { g_host->AttributeProto__add_ints(this, value); } + void add_floats(float value) { g_host->AttributeProto__add_floats(this, value); } + void add_strings(const ::std::string& value) { g_host->AttributeProto__add_strings(this, value); } + int64_t i() const { return g_host->AttributeProto__i(this); } float f() const { return g_host->AttributeProto__f(this); } + const ONNX_NAMESPACE::TensorProto& t() const { return g_host->AttributeProto__t(this); } void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); } + void set_f(const float& value) { return g_host->AttributeProto__set_f(this, value); } void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); } const ::std::string& s() const { return g_host->AttributeProto__s(this); } void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); } @@ -121,6 +153,8 @@ struct GraphProto final { NodeProto* add_node() { return g_host->GraphProto__add_node(this); } NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); } + std::string* mutable_name() { return g_host->GraphProto__mutable_name(this); } + GraphProto() = delete; GraphProto(const GraphProto&) = delete; }; @@ -133,7 +167,7 @@ struct ModelProto final { bool SerializeToOstream(std::ostream& output) const { return g_host->ModelProto__SerializeToOstream(this, output); } bool ParseFromString(const std::string& data) { return g_host->ModelProto__ParseFromString(this, data); } std::string SerializeAsString() const { return g_host->ModelProto__SerializeAsString(this); } - + StringStringEntryProtos* mutable_metadata_props() { return g_host->ModelProto__mutable_metadata_props(this); }; const GraphProto& graph() const { return 
g_host->ModelProto__graph(this); } GraphProto* mutable_graph() { return g_host->ModelProto__mutable_graph(this); } @@ -162,17 +196,22 @@ struct TensorProto final { void operator=(const TensorProto& v) { g_host->TensorProto__operator_assign(this, v); } bool has_name() const { return g_host->TensorProto__has_name(this); } + void set_name(const ::std::string& name) { return g_host->TensorProto__set_name(this, name); } + const ::std::string& name() const { return g_host->TensorProto__name(this); } int dims_size() const { return g_host->TensorProto__dims_size(this); } const int64s& dims() const { return g_host->TensorProto__dims(this); } + void add_dims(int64_t value) { g_host->TensorProto__add_dims(this, value); } bool has_data_location() const { return g_host->TensorProto__has_data_location(this); } TensorProto_DataLocation data_location() const { return TensorProto_DataLocation(g_host->TensorProto__data_location(this)); } bool has_raw_data() const { return g_host->TensorProto__has_raw_data(this); } const std::string& raw_data() const { return g_host->TensorProto__raw_data(this); } + std::string* mutable_raw_data() { return g_host->TensorProto__mutable_raw_data(this); } int32_t data_type() const { return g_host->TensorProto__data_type(this); } + void set_data_type(int32_t type) { return g_host->TensorProto__set_data_type(this, type); } typedef TensorProto_DataType DataType; static constexpr DataType UNDEFINED = TensorProto_DataType_UNDEFINED; @@ -180,6 +219,13 @@ struct TensorProto final { static bool DataType_IsValid(int value) { return g_host->TensorProto_DataType_IsValid(value); } void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); } + StringStringEntryProtos* mutable_external_data() { return g_host->TensorProto__mutable_external_data(this); }; + void clear_float_data() { return g_host->TensorProto__clear_float_data(this); } + void clear_int32_data() { return g_host->TensorProto__clear_int32_data(this); } + void clear_string_data() { return g_host->TensorProto__clear_string_data(this); } + void clear_int64_data() { return g_host->TensorProto__clear_int64_data(this); } + void clear_double_data() { return g_host->TensorProto__clear_double_data(this); } + void clear_uint64_data() { return g_host->TensorProto__clear_uint64_data(this); } TensorProto() = delete; TensorProto(const TensorProto&) = delete; @@ -187,6 +233,8 @@ struct TensorProto final { struct TensorProtos final { TensorProto* Add() { return g_host->TensorProtos__Add(this); } + int size() { return g_host->TensorProtos__size(this); } + TensorProto& at(int index) { return g_host->TensorProtos__at(this, index); } PROVIDER_DISALLOW_ALL(TensorProtos) }; @@ -205,6 +253,8 @@ struct TensorShapeProto_Dimension final { bool has_dim_value() const { return g_host->TensorShapeProto_Dimension__has_dim_value(this); } bool has_dim_param() const { return g_host->TensorShapeProto_Dimension__has_dim_param(this); } void clear_dim_value() { return g_host->TensorShapeProto_Dimension__clear_dim_value(this); } + const std::string& denotation() const { return g_host->TensorShapeProto_Dimension__denotation(this); } + void set_denotation(const std::string& value) { g_host->TensorShapeProto_Dimension__set_denotation(this, value); } PROVIDER_DISALLOW_ALL(TensorShapeProto_Dimension) }; @@ -232,6 +282,7 @@ struct TypeProto_Tensor final { const TensorShapeProto& shape() const { return g_host->TypeProto_Tensor__shape(this); } TensorShapeProto* mutable_shape() { return g_host->TypeProto_Tensor__mutable_shape(this); } int32_t 
elem_type() const { return g_host->TypeProto_Tensor__elem_type(this); } + void set_elem_type(int32_t value) { g_host->TypeProto_Tensor__set_elem_type(this, value); } PROVIDER_DISALLOW_ALL(TypeProto_Tensor) }; @@ -315,7 +366,6 @@ struct ValueInfoProtos final { PROVIDER_DISALLOW_ALL(ValueInfoProtos) }; - } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -603,6 +653,10 @@ struct Function final { }; struct Node final { + enum class Type { + Primitive = 0, + Fused = 1, + }; const std::string& Name() const noexcept { return g_host->Node__Name(this); } const std::string& Description() const noexcept { return g_host->Node__Description(this); } const std::string& Domain() const noexcept { return g_host->Node__Domain(this); } @@ -626,6 +680,10 @@ struct Node final { void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const { return g_host->Node__ToProto(this, proto, update_subgraphs); } const NodeAttributes& GetAttributes() const noexcept { return g_host->Node__GetAttributes(this); } + void AddAttribute(const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) { + g_host->Node__AddAttribute(this, attr_name, value); + } + size_t GetInputEdgesCount() const noexcept { return g_host->Node__GetInputEdgesCount(this); } size_t GetOutputEdgesCount() const noexcept { return g_host->Node__GetOutputEdgesCount(this); } @@ -661,12 +719,15 @@ struct Node final { std::unique_ptr impl_; }; + EdgeConstIterator InputEdgesBegin() const noexcept { return g_host->Node__InputEdgesBegin(this); } + EdgeConstIterator InputEdgesEnd() const noexcept { return g_host->Node__InputEdgesEnd(this); } EdgeConstIterator OutputEdgesBegin() const noexcept { return g_host->Node__OutputEdgesBegin(this); } EdgeConstIterator OutputEdgesEnd() const noexcept { return g_host->Node__OutputEdgesEnd(this); } void ForEachDef(std::function func, bool include_missing_optional_defs = false) const { g_host->Node__ForEachDef(this, func, std::move(include_missing_optional_defs)); } const std::unordered_map>& GetAttributeNameToMutableSubgraphMap() { return g_host->Node__GetAttributeNameToMutableSubgraphMap(this); } std::unordered_map> GetAttributeNameToSubgraphMap() const { return g_host->Node__GetAttributeNameToSubgraphMap(this); } + Type NodeType() const noexcept { return Type(g_host->Node__NodeType(this)); } PROVIDER_DISALLOW_ALL(Node) }; @@ -678,6 +739,7 @@ struct NodeArg final { const NodeArgInfo& ToProto() const noexcept { return g_host->NodeArg__ToProto(this); } bool Exists() const noexcept { return g_host->NodeArg__Exists(this); } const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept { return g_host->NodeArg__TypeAsProto(this); } + Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) { return g_host->NodeArg__OverrideTypesHelper(this, input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); } PROVIDER_DISALLOW_ALL(NodeArg) }; @@ -698,6 +760,8 @@ struct NodeAttributes final { IteratorHolder> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); } void insert(const NodeAttributes& v) { return g_host->NodeAttributes__insert(this, v); } void emplace(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__emplace(this, k, v); } + void insert_or_assign(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__insert_or_assign(this, k, v); } + void reserve(size_t size) { 
g_host->NodeAttributes__reserve(this, size); } NodeAttributes() = delete; @@ -705,11 +769,18 @@ struct NodeAttributes final { }; struct Model final { + static std::unique_ptr Create(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, + const logging::Logger& logger) { + return g_host->Model__construct(std::move(model_proto), model_path, logger); + } static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast(p)); } + static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); } Graph& MainGraph() { return g_host->Model__MainGraph(this); } std::unique_ptr ToProto() { return g_host->Model__ToProto(this); } + std::unique_ptr ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) { return g_host->Model__ToGraphProtoWithExternalInitializers(this, external_file_name, file_path, initializer_size_threshold); } + const ModelMetaData& MetaData() const noexcept { return g_host->Model__MetaData(this); } Model() = delete; Model(const Model&) = delete; @@ -732,6 +803,7 @@ struct Graph final { void SetOutputs(gsl::span outputs) { return g_host->Graph__SetOutputs(this, outputs); } const std::vector& GetInputs() const noexcept { return g_host->Graph__GetInputs(this); } + std::vector Nodes() const noexcept { return g_host->Graph__Nodes(this); } bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { return g_host->Graph__GetInitializedTensor(this, tensor_name, value); } @@ -742,6 +814,37 @@ struct Graph final { const Path& ModelPath() const { return g_host->Graph__ModelPath(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); } bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); } + const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->Graph__GetProducerNode(this, node_arg_name); } + const Model& GetModel() const { return g_host->Graph__GetModel(this); } + void ReverseDFSFrom(gsl::span from, const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const { + g_host->Graph__ReverseDFSFrom(this, from, enter, leave, comp, stop); + } + Graph& SetGraphResolveNeeded() { return g_host->Graph__SetGraphResolveNeeded(this); } + void RemoveInitializedTensor(const std::string& tensor_name) { g_host->Graph__RemoveInitializedTensor(this, tensor_name); } + + std::vector GetConsumerNodes(const std::string& node_arg_name) const { + return g_host->Graph__GetConsumerNodes(this, node_arg_name); + } + void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { + g_host->Graph__AddEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { + g_host->Graph__RemoveEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void RemoveNode(NodeIndex index) { g_host->Graph__RemoveNode(this, index); } + Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { + return g_host->Graph__FuseSubGraph(this, sub_graph, fused_node_name); + } + void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) { + 
g_host->Graph__UpdateProducerNode(this, node_arg_name, node_index); + } + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const { + return g_host->Graph__GetConstantInitializer(this, name, check_outer_scope); + } + const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return g_host->Graph__GetAllInitializedTensors(this); } int MaxNodeIndex() const noexcept { return g_host->Graph__MaxNodeIndex(this); } const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); } Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); } @@ -783,6 +886,7 @@ class GraphViewer final { const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } + const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; @@ -790,11 +894,16 @@ class GraphViewer final { }; struct Path final { + static std::unique_ptr Create() { return g_host->Path__construct(); } + static void operator delete(void* p) { g_host->Path__operator_delete(reinterpret_cast(p)); } + PathString ToPathString() const noexcept { return g_host->Path__ToPathString(this); } const std::vector& GetComponents() const noexcept { return g_host->Path__GetComponents(this); } bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); } - PROVIDER_DISALLOW_ALL(Path) + Path() = delete; + Path(const Path&) = delete; + void operator=(const Path&) = delete; }; struct OpKernelContext final { diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc index 29bc886fb5ed..1392ecef1b72 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc @@ -2,126 +2,106 @@ // Licensed under the MIT License. 
#include "./attr_proto.h" -#include "./vai_assert.h" - #include #include #include #include -namespace vaip { +#include "core/providers/shared_library/provider_api.h" -ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, - int64_t value) { - auto ret = new onnx::AttributeProto(); +#include "./vai_assert.h" + +namespace vaip { +ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_INT); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); ret->set_i(value); - return ret; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, - float value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_FLOAT); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); ret->set_f(value); - return ret; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_string( - const std::string& name, const std::string& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_STRING); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); ret->set_s(value); - return ret; + return ret.release(); } ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( const std::string& name, const ONNX_NAMESPACE::TensorProto& value) { - auto ret = new onnx::AttributeProto(); + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_TENSOR); - *ret->mutable_t() = value; - return ret; + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR); + *ret->add_tensors() = value; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( - const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_INTS); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS); ret->mutable_ints()->Reserve((int)value.size()); for (auto v : value) { ret->add_ints(v); } - return ret; + return ret.release(); } - ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_FLOATS); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS); ret->mutable_floats()->Reserve((int)value.size()); for (auto v : value) { ret->add_floats(v); } - return ret; + return ret.release(); } - -ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( - const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value) { + auto ret = 
ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_STRINGS); - ret->mutable_strings()->Reserve((int)value.size()); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS); for (auto& v : value) { ret->add_strings(v); } - return ret; + return ret.release(); } - -int64_t attr_proto_get_int(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INT, attr.DebugString()); +int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT, attr.name()); return attr.i(); } - -float attr_proto_get_float(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOAT, attr.DebugString()); +float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT, attr.name()); return attr.f(); } - -const std::string& attr_proto_get_string(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRING, attr.DebugString()); +const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRING, attr.name()); return attr.s(); } - -const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( - const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_TENSOR, attr.DebugString()); +const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR, attr.name()); return attr.t(); } - -gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INTS, attr.DebugString()); +gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INTS, attr.name()); return gsl::span(attr.ints()); } - -gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOATS, attr.DebugString()); +gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS, attr.name()); return gsl::span(attr.floats()); } - -std::vector attr_proto_get_strings( - const ONNX_NAMESPACE::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRINGS, attr.DebugString()); - return std::vector(attr.strings().begin(), attr.strings().end()); -} - -ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, - int64_t value) { - ONNX_NAMESPACE::AttributeProto ret; - ret.set_name(name); - ret.set_i(value); +std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS, attr.name()); + std::vector ret; + ret.reserve(attr.strings_size()); + for (int i = 0; i < attr.strings_size(); i++) { + ret.push_back(attr.strings(i)); + } return ret; } - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index 32ba8fa672d7..f4d56dd618a8 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ 
b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -2,46 +2,26 @@ // Licensed under the MIT License. #pragma once #include - +#include "vaip/my_ort.h" #include "core/common/gsl.h" -#include "onnx/onnx_pb.h" namespace vaip { -ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, - int64_t value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, - float value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, - const std::string& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( - const std::string& name, const ONNX_NAMESPACE::TensorProto& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( - const std::string& name, const std::vector& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( - const std::string& name, const std::vector& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( - const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor(const std::string& name, const ONNX_NAMESPACE::TensorProto& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats(const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value); /// attr_proto getters int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr); float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr); -const std::string& attr_proto_get_string( - const ONNX_NAMESPACE::AttributeProto& attr); - -const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( - const onnx::AttributeProto& attr); -gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr); -gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr); -std::vector attr_proto_get_strings( - const ONNX_NAMESPACE::AttributeProto& attr); - -/// attr_proto makers -ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, - int64_t); - -/// -using attr_proto_func_t = std::function; +const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr); +const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr); +gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr); +gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr); +std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/capability.cc b/onnxruntime/core/providers/vitisai/imp/capability.cc index a55180bd2ee5..58522a45a151 100644 --- a/onnxruntime/core/providers/vitisai/imp/capability.cc +++ b/onnxruntime/core/providers/vitisai/imp/capability.cc @@ -3,15 +3,10 @@ #include "vaip/capability.h" #include "./vai_assert.h" -#include "core/graph/basic_types.h" - -#include "./attr_proto.h" - namespace vaip { using namespace ::onnxruntime; -static std::vector node_names_to_nodes(const GraphViewer& graph, - const std::vector& node_names) { +static std::vector node_names_to_nodes(const GraphViewer& graph, const std::vector& 
node_names) { auto ret = std::vector(); ret.reserve(node_names.size()); for (auto& onnx_node_name : node_names) { @@ -24,53 +19,45 @@ static std::vector node_names_to_nodes(const GraphViewer& graph, } std::unique_ptr XirSubgraphToComputeCapability1(const onnxruntime::GraphViewer& graph, vaip_core::ExecutionProvider* ep, size_t index) { - auto meta_def = std::make_unique(); - meta_def->constant_initializers = *ep->get_meta_def_constant_initializer(); - meta_def->inputs = *ep->get_meta_def_inputs(); - meta_def->outputs = *ep->get_meta_def_outputs(); - auto indexed_subgraph = std::make_unique(); - auto indexed_subgraph_ptr = indexed_subgraph.get(); - indexed_subgraph_ptr->nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->constant_initializers() = *ep->get_meta_def_constant_initializer(); + meta_def->inputs() = *ep->get_meta_def_inputs(); + meta_def->outputs() = *ep->get_meta_def_outputs(); + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); static auto g_counter = 1; - meta_def->name = std::string("vitis_ai_ep_") + std::to_string(g_counter++); - meta_def->domain = "com.xilinx"; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - auto index_proto = std::unique_ptr(vaip::attr_proto_new_int("index", (int64_t)index)); - meta_def->attributes["index"] = *index_proto; + meta_def->name() = std::string("vitis_ai_ep_") + std::to_string(g_counter++); + meta_def->domain() = "com.xilinx"; + meta_def->since_version() = 1; + meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; + auto index_proto = ONNX_NAMESPACE::AttributeProto::Create(); + index_proto->set_name("index"); + index_proto->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + index_proto->set_i(index); + meta_def->attributes()["index"] = *index_proto; indexed_subgraph->SetMetaDef(std::move(meta_def)); - return std::make_unique(std::move(indexed_subgraph)); + return ComputeCapability::Create(std::move(indexed_subgraph)); } std::vector> GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph, vaip_core::DllSafe>>* eps, - const std::set& all_not_support_optypes) { - std::set all_compute_capability_nodes; + const std::set& all_support_optypes_by_eps) { + std::set all_nodes_included_eps; for (auto& ep : **eps) { - auto nodes = *ep->get_meta_def_nodes(); - for (auto n : nodes) - all_compute_capability_nodes.insert(n); + auto nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + all_nodes_included_eps.insert(nodes.begin(), nodes.end()); } + + std::vector node_indexs = graph.GetNodesInTopologicalOrder(); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end()); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end()); + std::vector> result; - for (auto& n : graph.Nodes()) { - if ((!all_compute_capability_nodes.count(n.Name())) && all_not_support_optypes.count(n.OpType())) { - auto meta_def = std::make_unique(); - meta_def->name = n.OpType(); - meta_def->domain = n.Domain(); - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - auto indexed_subgraph = std::make_unique(); - indexed_subgraph->nodes.push_back(n.Index()); - for (auto i : n.InputDefs()) { - 
meta_def->inputs.push_back(i->Name()); - } - for (auto i : n.OutputDefs()) { - meta_def->outputs.push_back(i->Name()); - } - indexed_subgraph->SetMetaDef(std::move(meta_def)); - result.emplace_back(std::make_unique(std::move(indexed_subgraph))); - } + for (auto& n : node_indexs) { + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = {n}; + result.emplace_back(ComputeCapability::Create(std::move(indexed_subgraph))); } return result; } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index b629c8eff909..f609d40f459b 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -1,20 +1,18 @@ - // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. + #include "vaip/global_api.h" #include +#include +#include #include #include "./vai_assert.h" -#include "core/common/exceptions.h" -#include "core/common/logging/logging.h" +#include "core/common/exceptions.h" #include "core/framework/error_code_helper.h" - -#include "core/graph/model.h" -#include "core/session/ort_env.h" -#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/shared/common.h" #include @@ -55,16 +53,14 @@ struct OrtVitisAIEpAPI { std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); void Ensure() { - if (handle_) return; - auto full_path = Env::Default().GetRuntimePath() + - PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); - ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); - ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( - handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); - auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", - reinterpret_cast(&compile_onnx_model_with_options)); - auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", - reinterpret_cast(&compile_onnx_model_3)); + if (handle_) + return; + auto& env = Provider_GetHost()->Env__Default(); + auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); + auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); + auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", (void**)&compile_onnx_model_3); if (!status1.IsOK() && !status2.IsOK()) { ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); ORT_THROW(status1); @@ -76,6 +72,12 @@ struct OrtVitisAIEpAPI { }; static OrtVitisAIEpAPI s_library_vitisaiep; +static std::shared_ptr s_kernel_registry_vitisaiep; +static std::vector s_domains_vitisaiep; +static vaip_core::OrtApiForVaip the_global_api; +std::shared_ptr get_kernel_registry_vitisaiep() { return s_kernel_registry_vitisaiep; } +const std::vector& get_domains_vitisaiep() { return s_domains_vitisaiep; } + static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { auto iter = 
config.find("config_file"); if (iter == config.end()) { @@ -105,121 +107,142 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config return ""; } } -vaip_core::DllSafe>> compile_onnx_model_with_options( - const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { + +vaip_core::DllSafe>> compile_onnx_model( + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { +#ifndef _WIN32 + auto model_path = graph_viewer.ModelPath().ToPathString(); +#else + using convert_t = std::codecvt_utf8; + std::wstring_convert strconverter; + auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); +#endif if (s_library_vitisaiep.compile_onnx_model_with_options) { - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); } else { auto json_str = config_to_json_str(options); - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph_viewer.GetGraph(), json_str.c_str())); } } -std::vector initialize_vitisai_ep() { - s_library_vitisaiep.Ensure(); - Status status = Status::OK(); - try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, - "onnxruntime-vitisai-ep"}; - std::ignore = OrtEnv::GetInstance(lm_info, status); - } catch (onnxruntime::OnnxRuntimeException& /*e*/) { +struct MyCustomOpKernel : OpKernel { + MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { + op_kernel_ = + op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); } - auto domains = std::vector(); - domains.reserve(100); - s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { - vaip::register_xir_ops(domains); + + ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } + + Status Compute(OpKernelContext* ctx) const override { + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); + return Status::OK(); } - return domains; + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); + + const OrtCustomOp& op_; + void* op_kernel_; +}; + +void create_kernel_registry(std::vector domains) { + s_kernel_registry_vitisaiep = KernelRegistry::Create(); + for (const auto& domain : domains) { + for (const auto* op : domain->custom_ops_) { + auto def_builder = KernelDefBuilder::Create(); + def_builder->SetName(op->GetName(op)); + def_builder->SetDomain(domain->domain_.c_str()); + def_builder->SinceVersion(1); + if (op->version > 12) { + auto input_count = op->GetInputTypeCount(op); + for (auto i = 0u; i < input_count; i++) { + def_builder->InputMemoryType(op->GetInputMemoryType(op, i), i); + } + } + def_builder->Provider(onnxruntime::kVitisAIExecutionProvider); + KernelCreateFn kernel_create_fn = + [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + // out = std::make_unique(info, *op); + return Status::OK(); + }; + std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), 
kernel_create_fn)); + } + } +} +void initialize_vitisai_ep() { + s_library_vitisaiep.Ensure(); + s_domains_vitisaiep.reserve(100); + s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), s_domains_vitisaiep); + vaip::register_xir_ops(s_domains_vitisaiep); + create_kernel_registry(s_domains_vitisaiep); } -static vaip_core::OrtApiForVaip the_global_api; vaip_core::OrtApiForVaip* create_org_api_hook() { + InitProviderOrtApi(); + the_global_api.host_ = Provider_GetHost(); assert(Ort::Global::api_ != nullptr); the_global_api.ort_api_ = Ort::Global::api_; the_global_api.model_load = [](const std::string& filename) -> Model* { - ONNX_NAMESPACE::ModelProto model_proto; + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); auto& logger = logging::LoggingManager::DefaultLogger(); auto file_path = ToPathString(filename); - auto status = Model::Load(file_path, model_proto); + auto status = Model::Load(file_path, *model_proto); vai_assert(status.IsOK(), "load model proto error"); - auto model = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto model = Model::Create(std::move(*model_proto), file_path, logger); return model.release(); }; the_global_api.model_delete = [](Model* model) { delete model; }; - the_global_api.model_clone = [](const Model& model) -> Model* { + + the_global_api.model_clone = [](const Model& const_model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = const_cast(model).ToProto(); - auto file_path = model.ModelPath().ToPathString(); - auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto& model = const_cast(const_model); + auto model_proto = model.ToProto(); + auto file_path = model.MainGraph().ModelPath().ToPathString(); + auto ret = Model::Create(std::move(*model_proto), file_path, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) { const_cast(model.MetaData())[key] = value; }; - the_global_api.model_get_meta_data = [](const Model& model, - const std::string& key) -> vaip_core::DllSafe { - auto& m = model.MetaData(); - auto it = m.find(key); - auto ret = std::string(); - if (it != m.end()) { - ret = it->second; + the_global_api.model_get_meta_data = + [](const Model& model, const std::string& key) -> vaip_core::DllSafe { + if (model.MetaData().count(key)) { + return vaip_core::DllSafe(model.MetaData().at(key)); } - return vaip_core::DllSafe(ret); + return vaip_core::DllSafe(std::string()); }; - the_global_api.model_has_meta_data = [](const Model& model, const std::string& key) -> int { - auto& m = model.MetaData(); - return m.find(key) != m.end() ? 
1 : 0; + return int(model.MetaData().count(key)); }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; - the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { - auto ret = std::vector(); - auto inputs = graph.GetInputs(); - for (auto input : inputs) { - vai_assert(input->Exists(), input->Name()); - ret.push_back(input); - } - return vaip_core::DllSafe(std::move(ret)); + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> auto { + return vaip_core::DllSafe(graph.GetInputs()); }; - the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.GetOutputs()); }; - - the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { - return graph.SetOutputs(outputs); + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) { + graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - - the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; - + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { + return graph.GetNode(index); + }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, const std::vector& input_args, - const std::vector& output_args, - vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { - return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), domain); - }; - + the_global_api.graph_add_node = vaip::graph_add_node; the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; - the_global_api.graph_resolve = [](Graph& graph, bool force) { if (force) { graph.SetGraphResolveNeeded(); @@ -227,129 +250,57 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { auto status = graph.Resolve(); return status.Code(); }; - - the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { + the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { - auto& node_refererence = graph.Nodes(); - std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); - return vaip_core::DllSafe(std::move(nodes)); - }; + the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); }; the_global_api.graph_get_name = [](const Graph& graph) -> const 
std::string& { return graph.Name(); }; the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { + const auto& enter, const auto& leave, const auto& stop) { graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; - the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return node.Index(); }; the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - - the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast(node.GetMutableAttributes()); - }; - - the_global_api.node_type_is_fused = [](const Node& node) { - return node.NodeType() == onnxruntime::Node::Type::Fused; + the_global_api.node_get_attributes = [](Node& node) -> NodeAttributes& { + return const_cast(node.GetAttributes()); }; - the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == Node::Type::Fused; }; + the_global_api.node_get_function_body = [](const Node& node) -> const auto& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { - return node_arg.Name(); - }; + the_global_api.node_arg_get_name_unsafe = + [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; the_global_api.node_arg_new = vaip::node_arg_new; - the_global_api.node_arg_is_exists = vaip::node_arg_is_exists; + the_global_api.node_arg_is_exists = [](const NodeArg& node_arg) { return node_arg.Exists(); }; the_global_api.node_arg_is_constant = vaip::node_arg_is_constant; the_global_api.node_arg_get_shape_i64_unsafe = vaip::node_arg_get_shape_i64; the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; + the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; - the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { - auto data_type = ONNX_NAMESPACE::TensorProto::UNDEFINED; - switch (type) { - case 1: - data_type = ONNX_NAMESPACE::TensorProto::FLOAT; - break; - case 2: - data_type = ONNX_NAMESPACE::TensorProto::UINT8; - break; - case 3: - data_type = ONNX_NAMESPACE::TensorProto::INT8; - break; - - case 4: - data_type = ONNX_NAMESPACE::TensorProto::UINT16; - break; - case 5: - data_type = ONNX_NAMESPACE::TensorProto::INT16; - break; - case 6: - data_type = ONNX_NAMESPACE::TensorProto::INT32; - break; - case 7: - data_type = ONNX_NAMESPACE::TensorProto::INT64; - break; 
- case 8: - data_type = ONNX_NAMESPACE::TensorProto::STRING; - break; - case 9: - data_type = ONNX_NAMESPACE::TensorProto::BOOL; - break; - case 10: - data_type = ONNX_NAMESPACE::TensorProto::FLOAT16; - break; - case 11: - data_type = ONNX_NAMESPACE::TensorProto::DOUBLE; - break; - case 12: - data_type = ONNX_NAMESPACE::TensorProto::UINT32; - break; - case 13: - data_type = ONNX_NAMESPACE::TensorProto::UINT64; - break; - case 14: - data_type = ONNX_NAMESPACE::TensorProto::COMPLEX64; - break; - case 15: - data_type = ONNX_NAMESPACE::TensorProto::COMPLEX128; - break; - case 16: - data_type = ONNX_NAMESPACE::TensorProto::BFLOAT16; - break; - default: - vai_assert(false, "TensorProto::DataType not supoort"); - } - return vaip::node_arg_set_element_type(node_arg, data_type); - }; + the_global_api.node_arg_set_element_type = vaip::node_arg_set_element_type; /// attr proto - the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { - return new onnx::AttributeProto(v); - }; - the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { - return attr_proto.name(); - }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { - attr_proto->set_name(name); + the_global_api.attr_proto_delete = [](ONNX_NAMESPACE::AttributeProto* v) { delete v; }; + the_global_api.attr_proto_clone = [](const ONNX_NAMESPACE::AttributeProto& v) -> ONNX_NAMESPACE::AttributeProto* { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); + *ret = v; + return ret.release(); }; + the_global_api.attr_proto_get_name = [](const auto& attr_proto) -> const std::string& { return attr_proto.name(); }; + the_global_api.attr_proto_set_name = [](auto* attr_proto, const auto& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; the_global_api.attr_proto_new_float = vaip::attr_proto_new_float; the_global_api.attr_proto_new_string = vaip::attr_proto_new_string; @@ -364,31 +315,24 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const ONNX_NAMESPACE::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes - the_global_api.node_attributes_new = []() { - return reinterpret_cast(new NodeAttributes()); - }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); + the_global_api.node_attributes_new = []() { return NodeAttributes::Create().release(); }; + the_global_api.node_attributes_add = [](NodeAttributes& p, ONNX_NAMESPACE::AttributeProto&& attr) { + p.insert_or_assign(attr.name(), std::move(attr)); }; - the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { - delete reinterpret_cast(p); - }; - the_global_api.node_attributes_get = [](vaip_core::NodeAttributes& p, - const std::string& name) -> ONNX_NAMESPACE::AttributeProto* { - auto& attr = reinterpret_cast(p); - auto it = attr.find(name); - if (it == attr.end()) { - return nullptr; + + the_global_api.node_attributes_delete = 
[](NodeAttributes* p) { delete p; }; + the_global_api.node_attributes_get = + [](const NodeAttributes& attr, const std::string& name) -> const ONNX_NAMESPACE::AttributeProto* { + if (attr.count(name)) { + return &attr.at(name); } - return &it->second; + return nullptr; }; - the_global_api.node_attributes_get_keys = - [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = [](NodeAttributes& attr) -> vaip_core::DllSafe> { auto ret = std::vector(); - auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); for (auto& it : attr) { ret.push_back(it.first); @@ -396,35 +340,16 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = - [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { - return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); - }; - - the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; - - the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - - the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; - }; - the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; - + the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape; + the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; }; + the_global_api.tensor_proto_new_floats = vaip::tensor_proto_new_floats; + the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32; + the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64; + the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8; + the_global_api.tensor_proto_raw_data_size = [](const auto& tensor) { return tensor.raw_data().size(); }; the_global_api.tensor_proto_as_raw = vaip::tensor_proto_as_raw; - the_global_api.tensor_proto_get_name = vaip::tensor_proto_get_name; + the_global_api.tensor_proto_get_name = [](const auto& tensor) -> const std::string& { return tensor.name(); }; the_global_api.get_lib_name = []() -> vaip_core::DllSafe { return vaip_core::DllSafe(std::string("onnxruntime.") + std::string(ORT_VERSION)); diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index cca680baf7dc..061bc414fcec 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -2,27 +2,15 @@ // Licensed under the MIT License. 
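
A minimal sketch of the construction pattern the refactor adopts throughout the hunks above: every ORT/ONNX object is created through the provider-bridge factory wrappers (Create()) instead of being instantiated directly, so the EP no longer needs onnx.pb.h or internal onnxruntime headers. Condensed from XirSubgraphToComputeCapability1; the literal name and index value here are placeholders.

    auto meta_def = IndexedSubGraph_MetaDef::Create();
    meta_def->name() = "vitis_ai_ep_1";  // accessors return mutable references into the wrapper
    meta_def->domain() = "com.xilinx";
    meta_def->since_version() = 1;
    auto index_proto = ONNX_NAMESPACE::AttributeProto::Create();  // heap-allocated wrapped proto
    index_proto->set_name("index");
    index_proto->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
    index_proto->set_i(0);
    meta_def->attributes()["index"] = *index_proto;  // copied into the attribute map
    auto indexed_subgraph = IndexedSubGraph::Create();
    indexed_subgraph->SetMetaDef(std::move(meta_def));
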
#include "vaip/graph.h" -#include - -#include "./vai_assert.h" #include #include #include #include #include #include -#include "onnx/onnx-ml.pb.h" -#ifdef _MSC_VER -#pragma warning(push) -// 'type' : forcing value to bool 'true' or 'false' (performance warning) -#pragma warning(disable : 4800) -#endif -#include -#ifdef _MSC_VER -#pragma warning(pop) -#endif -using convert_t = std::codecvt_utf8; -std::wstring_convert strconverter; + +#include "core/providers/shared_library/provider_api.h" +#include "./vai_assert.h" #include "vaip/node.h" #include "vaip/node_arg.h" @@ -38,23 +26,14 @@ struct NodeEdgeT { static void graph_remove_node(Graph& graph, const Node& node) { auto remove_edges = std::vector(); - auto begin = node.InputEdgesBegin(); - auto end = node.InputEdgesEnd(); - for (auto it = begin; it != end; ++it) { - remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), - it->GetSrcArgIndex(), - it->GetDstArgIndex()}); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); ++it) { + remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); } - begin = node.OutputEdgesBegin(); - end = node.OutputEdgesEnd(); - for (auto it = begin; it != end; ++it) { - remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), - it->GetSrcArgIndex(), - it->GetDstArgIndex()}); + for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); ++it) { + remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); } for (auto it : remove_edges) { - graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, - it.dst_arg_index); + graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, it.dst_arg_index); } graph.RemoveNode(node.Index()); } @@ -68,13 +47,9 @@ static std::vector node_get_implicit_input_node_args(const Node& } return ret; } - -Node& graph_add_node(Graph& graph, const std::string& name, - const std::string& op_type, const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes& attributes, - const std::string& domain) { +Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, + const std::vector& input_args, const std::vector& output_args, + const NodeAttributes& attributes, const std::string& domain) { std::vector inputs; inputs.reserve(input_args.size()); for (auto i : input_args) { @@ -85,8 +60,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto i : output_args) { outputs.push_back(const_cast(i)); } - auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, - &attributes, domain); + auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, &attributes, domain); auto src_arg_index = 0; for (auto& o : outputs) { auto consumers = graph.GetConsumerNodes(o->Name()); @@ -96,8 +70,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto ni : *tmp_inputs) { auto name1 = ni.node_arg->Name(); if (name1 == o->Name()) { - graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, - dst_arg_index); + graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); } dst_arg_index = dst_arg_index + 1; } @@ -105,8 +78,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto implicit_node_arg : node_get_implicit_input_node_args(*consumer)) { auto name1 = implicit_node_arg->Name(); if (name1 == o->Name()) { - 
graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, - dst_arg_index); + graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); } dst_arg_index = dst_arg_index + 1; } @@ -132,44 +104,39 @@ void graph_remove_node(Graph& graph, const NodeInput& node_input) { void graph_save(const Graph& graph, const std::string& filename, const std::string& filename_dat, size_t initializer_size_threshold) { auto& model = const_cast(graph.GetModel()); - auto model_proto = ONNX_NAMESPACE::ModelProto(); + std::unique_ptr model_proto; if (initializer_size_threshold == std::numeric_limits::max()) { model_proto = model.ToProto(); } else { - model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, - ToPathString(filename), - initializer_size_threshold); + model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, graph.ModelPath().ToPathString(), initializer_size_threshold); } auto& metadata = model.MetaData(); if (!metadata.empty()) { - model_proto.mutable_metadata_props()->Clear(); + auto metadata_props = model_proto->mutable_metadata_props(); + metadata_props->Clear(); for (auto& m : metadata) { - auto prop = model_proto.mutable_metadata_props()->Add(); + auto prop = metadata_props->Add(); *prop->mutable_key() = m.first; *prop->mutable_value() = m.second; } } // use relative path as data storage. - auto graph_proto = model_proto.mutable_graph(); - *graph_proto = graph.ToGraphProto(); - for (auto i = 0; i < graph_proto->initializer_size(); ++i) { - auto initializer = graph_proto->mutable_initializer(i); - for (auto j = 0; j < initializer->external_data_size(); ++j) { - auto external_data = initializer->mutable_external_data(j); - if (external_data->key() == "location") { - *external_data->mutable_value() = std::filesystem::path(external_data->value()).filename().u8string(); - } + auto graph_proto = model_proto->mutable_graph(); + *graph_proto = *graph.ToGraphProto(); + for (int i = 0; i < graph_proto->mutable_initializer()->size(); i++) { + auto mutable_external_data = graph_proto->mutable_initializer()->at(i).mutable_external_data(); + for (int j = 0; j < mutable_external_data->size(); j++) { + auto& external_data = mutable_external_data->at(j); + if (*external_data.mutable_key() == "location") + *external_data.mutable_value() = std::filesystem::path(*external_data.mutable_value()).filename().u8string(); } } - int fd = -1; - Status status = Env::Default().FileOpenWr(filename, fd); - vai_assert(status.IsOK(), status.ErrorMessage()); - google::protobuf::io::FileOutputStream output(fd); - const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); - vai_assert(result, "model serialize to zero cipy stream error"); - status = Env::Default().FileClose(fd); - vai_assert(status.IsOK(), status.ErrorMessage()); + + std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary); + bool result = model_proto->SerializeToOstream(output); + output << std::flush; + vai_assert(result, "model serialize to ostream error"); } Node& graph_fuse(Graph& graph, const std::string& name, @@ -178,25 +145,25 @@ Node& graph_fuse(Graph& graph, const std::string& name, const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers) { - auto meta_def = std::make_unique(); - auto indexed_subgraph = std::make_unique(); - indexed_subgraph->nodes = nodes; - meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->constant_initializers = constant_initializers; - meta_def->name = "super_layer"; - 
meta_def->domain = "com.xilinx"; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->inputs() = inputs; + meta_def->outputs() = outputs; + meta_def->constant_initializers() = constant_initializers; + meta_def->name() = "super_layer"; + meta_def->domain() = "com.xilinx"; + meta_def->since_version() = 1; + meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; + + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = nodes; indexed_subgraph->SetMetaDef(std::move(meta_def)); + auto& fused_node = graph.FuseSubGraph(*indexed_subgraph, name); auto function_body = fused_node.GetFunctionBody(); if (function_body) { - auto& mygraph = function_body->Body(); - // auto proto = graph.ToGraphProtoWithExternal("exteranl.dat", 128); - auto proto = mygraph.ToGraphProto(); - *proto.mutable_name() = name; - fused_node.AddAttribute("body", proto); + auto proto = function_body->Body().ToGraphProto(); + *proto->mutable_name() = name; + fused_node.AddAttribute("body", *proto); } for (auto&& o : fused_node.OutputDefs()) { graph.UpdateProducerNode(o->Name(), fused_node.Index()); diff --git a/onnxruntime/core/providers/vitisai/imp/node.cc b/onnxruntime/core/providers/vitisai/imp/node.cc index 6d65ad4e8c40..0565171fb7f4 100644 --- a/onnxruntime/core/providers/vitisai/imp/node.cc +++ b/onnxruntime/core/providers/vitisai/imp/node.cc @@ -4,9 +4,8 @@ #include "./vai_assert.h" #include "attr_proto.h" -#include "core/graph/graph_utils.h" -#include "core/graph/node_arg.h" #include "vaip/node_arg.h" +#include "core/providers/shared_library/provider_api.h" namespace vaip { @@ -29,7 +28,6 @@ vaip_core::DllSafe> node_get_inputs(const Node& node) { } return vaip_core::DllSafe(ret); } - vaip_core::DllSafe> node_get_output_node_args(const Node& node) { auto outputs = node.OutputDefs(); auto size = outputs.size(); @@ -42,11 +40,4 @@ vaip_core::DllSafe> node_get_output_node_args(const } return vaip_core::DllSafe(ret); } - -vaip_core::DllSafe> node_get_output_shape(const Node& node, int index) { - auto outputs = node.OutputDefs(); - assert((size_t)index < outputs.size()); - return node_arg_get_shape_i64(*outputs[index]); -} - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_arg.cc b/onnxruntime/core/providers/vitisai/imp/node_arg.cc index 3bdeb09698d4..a54cbef91c39 100644 --- a/onnxruntime/core/providers/vitisai/imp/node_arg.cc +++ b/onnxruntime/core/providers/vitisai/imp/node_arg.cc @@ -2,25 +2,16 @@ // Licensed under the MIT License. 
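
For reference, the serialization step that graph_save now performs, condensed from the hunk above (filename and model_proto are assumed to be in scope exactly as they are there): the wrapped ModelProto is written straight to a binary std::fstream instead of going through Env::Default().FileOpenWr and a protobuf zero-copy stream.

    std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary);
    bool ok = model_proto->SerializeToOstream(output);  // the wrapped ModelProto serializes itself
    output << std::flush;
    vai_assert(ok, "model serialize to ostream error");
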
#include "vaip/node_arg.h" #include "./vai_assert.h" - -#include +#include "core/providers/shared_library/provider_api.h" #include "./tensor_proto.h" -#include "core/graph/node_arg.h" namespace vaip { - -bool node_arg_is_exists(const NodeArg& node_arg) { - return node_arg.Exists(); -} bool node_arg_is_constant(const Graph& graph, const NodeArg& node_arg) { assert(node_arg.Exists()); assert(!node_arg.Name().empty()); - auto constant_tensor_proto = - graph.GetConstantInitializer(node_arg.Name(), true); - return constant_tensor_proto != nullptr; + return graph.GetConstantInitializer(node_arg.Name(), true) != nullptr; } - vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& node_arg) { auto shape = node_arg.Shape(); if (nullptr == shape) return vaip_core::DllSafe>(); @@ -32,104 +23,42 @@ vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& n } return vaip_core::DllSafe(shape_vector); } - -static void LayoutTransformRule_set_shape(onnx::TensorShapeProto& shape_proto, - const std::vector& shape) { - assert(shape.size() == static_cast(shape_proto.dim_size())); - auto rank = shape_proto.dim_size(); +void node_arg_set_shape_i64(const NodeArg& node_arg, const std::vector& shape) { + auto shape_proto = const_cast(node_arg.Shape()); + assert(shape_proto != nullptr); + assert(shape.size() == static_cast(shape_proto->dim_size())); + auto rank = shape_proto->dim_size(); for (auto i = 0; i < rank; ++i) { - shape_proto.mutable_dim(i)->set_dim_value(shape[i]); + shape_proto->mutable_dim(i)->set_dim_value(shape[i]); } } - -static void LayoutTransformRule_set_shape(onnx::TypeProto& type_proto, - const std::vector& shape) { - assert(type_proto.value_case() == onnx::TypeProto::kTensorType); - //<< type_proto.DebugString(); - auto& tensor_type = *type_proto.mutable_tensor_type(); - auto& shape_prot = *tensor_type.mutable_shape(); - return LayoutTransformRule_set_shape(shape_prot, shape); -} - -static void LayoutTransformRule_set_shape(NodeArg* node_arg, - const std::vector& shape) { - assert(node_arg != nullptr); - auto* type_proto = node_arg->TypeAsProto(); - assert(type_proto != nullptr); - return LayoutTransformRule_set_shape( - *const_cast(type_proto), shape); -} - -void node_arg_set_shape_i64(const NodeArg& node_arg, - const std::vector& shape) { - LayoutTransformRule_set_shape(const_cast(&node_arg), shape); -} - -static std::vector LayoutTransformRule_get_denotation( - const onnx::TensorShapeProto& shape) { +vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { + auto shape = node_arg.Shape(); + if (shape == nullptr) { + return vaip_core::DllSafe>(); + } auto ret = std::vector(); - auto rank = shape.dim_size(); - ret.reserve(rank); + auto rank = shape->dim_size(); for (auto i = 0; i < rank; ++i) { - auto& d = shape.dim(i).denotation(); - ret.push_back(d); + ret.push_back(shape->dim(i).denotation()); } - return ret; + return vaip_core::DllSafe>(ret); } - -static vaip_core::DllSafe> LayoutTransformRule_get_denotation( - const onnx::TypeProto& type_proto) { - vai_assert(type_proto.value_case() == onnx::TypeProto::kTensorType, type_proto.DebugString()); - auto& tensor_type = type_proto.tensor_type(); - if (!tensor_type.has_shape()) { - return vaip_core::DllSafe>(); - } - auto& shape = tensor_type.shape(); - auto denotation = LayoutTransformRule_get_denotation(shape); - return vaip_core::DllSafe>(denotation); -} - -static vaip_core::DllSafe> LayoutTransformRule_get_denotation( - const NodeArg* node_arg) { - assert(node_arg != nullptr); - auto* type_proto = node_arg->TypeAsProto(); - 
assert(type_proto != nullptr); - return LayoutTransformRule_get_denotation(*type_proto); -} - -vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { - return LayoutTransformRule_get_denotation(&node_arg); -} - -static onnx::TensorShapeProto* node_arg_get_tensor_mutable_shape( - NodeArg* node_arg) { - assert(node_arg != nullptr); - auto type_proto = const_cast(node_arg->TypeAsProto()); - assert(type_proto != nullptr); - vai_assert(type_proto->value_case() == onnx::TypeProto::kTensorType, - type_proto->DebugString()); - return type_proto->mutable_tensor_type()->mutable_shape(); -} - -static void LayoutTransformRule_set_denotation( - onnx::TensorShapeProto& shape, const std::vector& denotation) { - assert(denotation.size() == static_cast(shape.dim_size())); - auto rank = shape.dim_size(); +void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation) { + auto shape_proto = const_cast(node_arg.Shape()); + assert(shape_proto != nullptr); + assert(denotation.size() == static_cast(shape_proto->dim_size())); + auto rank = shape_proto->dim_size(); for (auto i = 0; i < rank; ++i) { - shape.mutable_dim(i)->set_denotation(denotation[i]); + shape_proto->mutable_dim(i)->set_denotation(denotation[i]); } } -void node_arg_set_denotation(const NodeArg& node_arg, - const std::vector& denotation) { - auto mutable_shape = - node_arg_get_tensor_mutable_shape(const_cast(&node_arg)); - - return LayoutTransformRule_set_denotation(*mutable_shape, denotation); -} - -void node_arg_set_element_type(NodeArg& node_arg, - onnx::TensorProto::DataType data_type) { - auto type_proto = const_cast(node_arg.TypeAsProto()); +void node_arg_set_element_type(NodeArg& node_arg, int type) { + if (type < 0 || type > 16) { + vai_assert(false, "TensorProto::DataType not supoort"); + } + auto data_type = static_cast(type); + auto type_proto = const_cast(node_arg.TypeAsProto()); assert(type_proto != nullptr); auto current_elem_type = type_proto->mutable_tensor_type()->elem_type(); auto input_elem_type = data_type; @@ -138,24 +67,12 @@ void node_arg_set_element_type(NodeArg& node_arg, current_elem_type, true); vai_assert(status.IsOK(), status.ErrorMessage()); } -void node_arg_set_shape(NodeArg& node_arg, std::vector shape) { - auto type_proto = const_cast(node_arg.TypeAsProto()); - assert(type_proto != nullptr); - for (auto i = 0u; i < shape.size(); i++) { - type_proto->mutable_tensor_type() - ->mutable_shape() - ->mutable_dim(i) - ->set_dim_value(shape[i]); - } -} - const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor( const Graph& graph, const NodeArg& node_arg) { auto tensor_proto = graph.GetConstantInitializer(node_arg.Name(), true); assert(tensor_proto != nullptr); return *tensor_proto; } - int node_arg_get_element_type(const NodeArg& node_arg) { auto type_proto = node_arg.TypeAsProto(); assert(type_proto != nullptr); @@ -164,9 +81,7 @@ int node_arg_get_element_type(const NodeArg& node_arg) { } return type_proto->tensor_type().elem_type(); } - -NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, - const std::string& name) { +NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, const std::string& name) { vai_assert(name != node_arg.Name(), "node arg must have a new unique name"); vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. 
")); auto type_proto = node_arg.TypeAsProto(); @@ -174,12 +89,10 @@ NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, auto& ret = graph.GetOrCreateNodeArg(name, type_proto); return ret; } - -NodeArg& node_arg_new(Graph& graph, - const std::string& name, const std::vector* shape, int element_type) { +NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vector* shape, int element_type) { vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. ")); - auto type_proto = onnx::TypeProto(); - auto tensor_type = type_proto.mutable_tensor_type(); + auto type_proto = ONNX_NAMESPACE::TypeProto::Create(); + auto tensor_type = type_proto->mutable_tensor_type(); tensor_type->set_elem_type(element_type); if (shape != nullptr) { auto shape_proto = tensor_type->mutable_shape(); @@ -189,8 +102,6 @@ NodeArg& node_arg_new(Graph& graph, } else { assert(tensor_type->has_shape() == false); } - auto& ret = graph.GetOrCreateNodeArg(name, &type_proto); - return ret; + return graph.GetOrCreateNodeArg(name, type_proto.release()); } - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc b/onnxruntime/core/providers/vitisai/imp/node_attrs.cc deleted file mode 100644 index e438266e2a4c..000000000000 --- a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#include "vaip/node_attrs.h" -#include "./vai_assert.h" - -namespace vaip { -static onnx::AttributeProto make_attribute(const std::string& name, - int64_t value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::INT); - ret.set_i(value); - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::vector value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::INTS); - for (auto v : value) { - ret.add_ints(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::string& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::STRING); - ret.set_s(value); - return ret; -} -static onnx::AttributeProto make_attribute( - const std::string& name, const std::vector& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::STRINGS); - for (auto v : value) { - ret.add_strings(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::vector& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::FLOATS); - for (auto v : value) { - ret.add_floats(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const onnx::TensorProto& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::TENSOR); - *(ret.mutable_t()) = std::move(value); - return ret; -} // namespace vaip - -NodeAttr::NodeAttr(const std::string& name, int64_t value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::string& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& 
name, - const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const onnx::TensorProto& value) - : attribute_proto_{make_attribute(name, value)} {} - -onnx::AttributeProto& NodeAttr::get() { return attribute_proto_; } - -NodeAttributesBuiler::NodeAttributesBuiler(size_t capacity) : attrs_{} { - attrs_.reserve(capacity); -} - -NodeAttributes NodeAttributesBuiler::build() { - auto ret = NodeAttributes(); - ret.reserve(attrs_.size()); - for (auto& node_attr : attrs_) { - onnx::AttributeProto& attr_proto = node_attr.get(); - auto name = attr_proto.name(); - ret.insert(std::make_pair(name, std::move(attr_proto))); - } - attrs_.clear(); - return ret; -} - -void NodeAttributesBuiler::merge_into(Node& node) { - merge_into(node.GetMutableAttributes()); -} - -void NodeAttributesBuiler::merge_into(NodeAttributes& attrs) { - for (auto& attr : attrs_) { - vai_assert(attr.get().has_name(), std::string("attr must has name " + attr.get().DebugString())); - auto name = attr.get().name(); - attrs.insert_or_assign(std::move(name), std::move(attr.get())); - } -} -} // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index ee8dfc6d03d1..97ed2d3b4b8a 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -1,130 +1,25 @@ - - // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. + #include "./register_xir_ops.h" #include "./vai_assert.h" - -#include "core/common/logging/logging.h" -#include "core/common/status.h" - -#include "core/framework/customregistry.h" - +#include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" -#include "core/session/custom_ops.h" -#include "core/session/inference_session.h" -#include "onnx/defs/schema.h" -#include "onnx/defs/shape_inference.h" using namespace onnxruntime; -namespace vaip { - -static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { - auto* shape = ctx.getAttribute("shape"); - auto* data_type = ctx.getAttribute("data_type"); - if (data_type->s() == "float32") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT); - } else if (data_type->s() == "int8") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8); - } else if (data_type->s() == "uint8") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8); - } else if (data_type->s() == "int32") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); - } else if (data_type->s() == "int64") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); - } else if (data_type->s() == "int1") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - } else if (data_type->s() == "bfloat16") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); - } else if (data_type->s() == "float16") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); - } else { - vai_assert(false, ", not supported data_type: " + data_type->s()); - } - if (shape != nullptr) { - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::appendDim(ONNX_NAMESPACE::getOutputShape(ctx, 0), shape->ints(i)); - } - } else { - // set scalar type. 
- auto* output_shape = ONNX_NAMESPACE::getOutputShape(ctx, 0); - output_shape->clear_dim(); - } - return; -} - -static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); -} - -static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - auto num_inputs = ctx.getNumInputs(); - - // Run inferencing on the subgraph - ONNX_NAMESPACE::GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body"); - if (!graphInferencer) { - fail_type_inference("body is missing."); - } - - std::vector input_data; - std::vector subgraph_input_types; - for (size_t i = 0; i < num_inputs; ++i) { - input_data.push_back(ctx.getInputData(i)); - subgraph_input_types.push_back(ctx.getInputType(i)); - } - std::vector output_types; - output_types = - graphInferencer->doInferencing(subgraph_input_types, input_data); - - auto num_outputs = ctx.getNumOutputs(); - auto num_of_the_subgraph_outputs = output_types.size(); - if (num_outputs != num_of_the_subgraph_outputs) { - fail_type_inference("super layer has ", num_outputs, - " but subgraphs produce ", num_of_the_subgraph_outputs); - } - for (size_t i = 0, end = output_types.size(); i < end; ++i) { - auto subgraph_output = output_types[i]; - auto* super_layer_output = ctx.getOutputType(i); - *super_layer_output = *subgraph_output; - } -} +namespace vaip { void register_xir_ops(const std::vector& domains) { - std::shared_ptr custom_registry; - auto status = CreateCustomRegistry(gsl::span(domains), custom_registry); - vai_assert(status.IsOK(), status.ErrorMessage()); for (auto domain : domains) { for (auto op : domain->custom_ops_) { auto name = op->GetName(op); - auto schema1 = custom_registry->GetOpschemaRegistry()->GetSchema(name, ORT_API_VERSION, domain->domain_); - auto schema2 = ::ONNX_NAMESPACE::OpSchema(); - schema2.SetName(schema1->Name()); - schema2.SetDomain(schema1->domain()); - auto n = 0; - for (auto input : schema1->inputs()) { - schema2.Input(n, input.GetName(), input.GetDescription(), std::string("T") + std::to_string(n), input.GetOption(), false, input.GetMinArity(), input.GetDifferentiationCategory()); - schema2.TypeConstraint(std::string("T") + std::to_string(n), DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); - n = n + 1; - } - auto m = n; - n = 0; - for (auto output : schema1->outputs()) { - auto type_str = std::string("T") + std::to_string(n + m); - schema2.Output(n, output.GetName(), output.GetDescription(), type_str, output.GetOption(), false, output.GetMinArity(), output.GetDifferentiationCategory()); - schema2.TypeConstraint(type_str, DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); - n = n + 1; - } - schema2.SinceVersion(1); - schema2.AllowUncheckedAttributes(); if ((std::string)name == "super_layer") { - schema2.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 1); } else if ((std::string)name == "FixNeuron") { - schema2.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 2); } else { - schema2.TypeAndShapeInferenceFunction(xir_shape_infer); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 3); } - ONNX_NAMESPACE::RegisterSchema(schema2, ORT_API_VERSION); } } } diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc 
b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index db03354bf4c4..48dcd220a150 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -1,20 +1,19 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #include "./tensor_proto.h" -#include "./vai_assert.h" -#include "core/framework/tensorprotoutils.h" #include #include +#include "./vai_assert.h" +#include "core/providers/shared_library/provider_api.h" namespace vaip { - -gsl::span tensor_proto_as_raw( - const ONNX_NAMESPACE::TensorProto& tensor) { +gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { std::vector unpacked_tensor; - auto s = onnxruntime::utils::UnpackInitializerData(tensor, onnxruntime::Path(), unpacked_tensor); + auto path = onnxruntime::Path::Create(); + auto s = onnxruntime::utils::UnpackInitializerData(tensor, *path, unpacked_tensor); mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); mut_tensor.clear_float_data(); mut_tensor.clear_int32_data(); @@ -27,78 +26,51 @@ gsl::span tensor_proto_as_raw( return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); } -size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) { - return tensor.raw_data().size(); -} - -std::vector tensor_proto_get_shape( - const onnx::TensorProto& tensor_proto) { +vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor_proto) { auto ret = std::vector(); int rank = tensor_proto.dims_size(); if (rank > 0) { - ret.reserve((size_t)rank); - for (auto i = 0; i < rank; ++i) { - ret.push_back(tensor_proto.dims(i)); + auto& dims = tensor_proto.dims(); + for (auto i = 0; i < dims.size(); ++i) { + ret.push_back(dims[i]); } } - return ret; + return vaip_core::DllSafe(ret); } - -const std::string& tensor_proto_get_name( - const ONNX_NAMESPACE::TensorProto& tensor) { - return tensor.name(); +static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, const std::vector& shape, + int data_type, const char* data, size_t data_size) { + auto tensor_proto = ONNX_NAMESPACE::TensorProto::Create(); + tensor_proto->set_name(name); + for (auto s : shape) { + tensor_proto->add_dims(s); + } + tensor_proto->set_data_type(data_type); + tensor_proto->mutable_raw_data()->assign(data, data_size); + return tensor_proto.release(); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT32); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT32, + reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - 
tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT64); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64, + reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT8); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8, + reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::FLOAT); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(float)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + reinterpret_cast(&data[0]), data.size() * sizeof(float)); } } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 00aa388c809c..292905ca734f 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -1,31 +1,20 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#pragma once -// -#include "core/common/gsl.h" -#include "onnx/onnx_pb.h" -namespace vaip { - -gsl::span tensor_proto_as_raw( - const ONNX_NAMESPACE::TensorProto& tensor); -size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor); - -std::vector tensor_proto_get_shape( - const ONNX_NAMESPACE::TensorProto& tensor); -const std::string& tensor_proto_get_name( - const ONNX_NAMESPACE::TensorProto& tensor); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( - const std::string& name, const std::vector& shape, - const std::vector& data); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( - const std::string& name, const std::vector& shape, - const std::vector& data); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( - const std::string& name, const std::vector& shape, - const std::vector& data); - -ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( - const std::string& name, const std::vector& shape, - const std::vector& data); +#include "vaip/my_ort.h" +#include "vaip/vaip_gsl.h" +#include "vaip/dll_safe.h" +namespace vaip { +gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor); +vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor); +const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, + const std::vector& data); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/capability.h b/onnxruntime/core/providers/vitisai/include/vaip/capability.h index d6b5ae34decc..e7644dbe8635 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/capability.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/capability.h @@ -2,8 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" +#include "core/providers/shared_library/provider_api.h" #include "vaip/custom_op.h" namespace vaip { using namespace ::onnxruntime; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index c446ab3aefcc..1f8b8802e86b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -2,16 +2,15 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#pragma once -#include -#include -#include - +#include "core/providers/shared_library/provider_api.h" +#define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/provider_options.h" #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" -std::vector initialize_vitisai_ep(); -vaip_core::DllSafe>> compile_onnx_model_with_options( - const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); +void initialize_vitisai_ep(); +vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); +std::shared_ptr get_kernel_registry_vitisaiep(); +const std::vector& get_domains_vitisaiep(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h index 9def8645709f..292fb2bb38b2 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h @@ -1,25 +1,19 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once -#include #include "./node.h" +#include "vaip/my_ort.h" namespace vaip { using namespace onnxruntime; void graph_remove_node(Graph& graph, const NodeInput& node_input); -Node& graph_add_node(Graph& graph, const std::string& name, - const std::string& op_type, const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes& attributes, - const std::string& domain); - -void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, size_t initializer_size_threshold); -Node& graph_fuse(Graph& graph, const std::string& name, - const std::string& op_type, - const std::vector& nodes, - const std::vector& inputs, - const std::vector& outputs, +Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, + const std::vector& input_args, const std::vector& output_args, + const NodeAttributes& attributes, const std::string& domain); +void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, + size_t initializer_size_threshold); +Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, + const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index d43ef1253715..46fc4ac9b2a5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -9,15 +9,17 @@ #include namespace onnxruntime { -class Model; -class Graph; -class GraphViewer; -class Node; -class NodeArg; +struct Model; +struct Graph; +struct GraphViewer; +struct Node; +struct NodeArg; +struct ProviderHost; +struct NodeAttributes; } // namespace onnxruntime namespace ONNX_NAMESPACE { -class AttributeProto; -class TensorProto; +struct AttributeProto; +struct TensorProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { TensorProto_DataType_UNDEFINED = 0, @@ -68,6 +70,7 @@ using onnxruntime::GraphViewer; using onnxruntime::Model; using onnxruntime::Node; using onnxruntime::NodeArg; +using onnxruntime::NodeAttributes; struct 
ModelDeleter { VAIP_DLL_SPEC void operator()(Model* tp) const; }; @@ -75,22 +78,17 @@ using ModelPtr = std::unique_ptr; struct AttributeProtoDeleter { VAIP_DLL_SPEC void operator()(AttributeProto* p) const; }; -using AttributeProtoPtr = - std::unique_ptr; +using AttributeProtoPtr = std::unique_ptr; struct TensorProtoDeleter { VAIP_DLL_SPEC void operator()(TensorProto* tp) const; }; using TensorProtoPtr = std::unique_ptr; -/// I cannot forward declare a using directive, because -/// std::unorderd_map required AttributeProto must be defiend. -class NodeAttributes; struct NodeAttributesDeleter { VAIP_DLL_SPEC void operator()(NodeAttributes* p) const; }; -using NodeAttributesPtr = - std::unique_ptr; +using NodeAttributesPtr = std::unique_ptr; /// get node's input /// when Node* is nullptr, it is a tensor in the initializer. /// node_arg is always non-null. diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node.h b/onnxruntime/core/providers/vitisai/include/vaip/node.h index bad7660f6674..31d9d4bd73b8 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node.h @@ -2,10 +2,6 @@ // Licensed under the MIT License. #pragma once - -#include - -#include "core/graph/node_arg.h" #include "vaip/dll_safe.h" #include "vaip/my_ort.h" namespace vaip { @@ -17,8 +13,4 @@ vaip_core::DllSafe> node_get_inputs(const Node& node); /// to support multiple outputs vaip_core::DllSafe> node_get_output_node_args(const Node& node); -/// get output shape -/// index is usually zero, because most operators only have a single output. -vaip_core::DllSafe> node_get_output_shape(const Node& node, int index = 0); - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h index 76432fc5b3a6..fca641c5e11c 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h @@ -2,9 +2,8 @@ // Licensed under the MIT License. #pragma once -#include #include "vaip/dll_safe.h" -#include +#include "vaip/my_ort.h" namespace vaip { using namespace onnxruntime; @@ -26,9 +25,7 @@ void node_arg_set_shape_i64(const NodeArg& node_arg, void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation); void node_arg_set_element_type(NodeArg& node_arg, - ONNX_NAMESPACE::TensorProto::DataType data_type); -void node_arg_set_shape(NodeArg& node_arg, std::vector shape); - + int data_type); const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(const Graph& graph, const NodeArg& node_arg); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h b/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h deleted file mode 100644 index 49cd1aad89f4..000000000000 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include - -#include - -#include "core/graph/basic_types.h" -namespace vaip { -using namespace onnxruntime; -class NodeAttr { - public: - NodeAttr(const std::string& name, int64_t value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const std::string& value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const onnx::TensorProto& value); - - onnx::AttributeProto& get(); - - private: - onnx::AttributeProto attribute_proto_; -}; - -class NodeAttributesBuiler { - public: - explicit NodeAttributesBuiler(size_t capacity = 10); - NodeAttributesBuiler(const NodeAttributesBuiler&) = delete; - NodeAttributesBuiler(NodeAttributesBuiler&&) = default; - /// after build, all attrs_ are cleared. - NodeAttributes build(); - /// for efficiency reason, after merge_into, all attrs_ are moved. - void merge_into(Node& node); - void merge_into(NodeAttributes& attrs); - template - NodeAttributesBuiler& add(const std::string& name, T&& value) { - attrs_.emplace_back(name, std::forward(value)); - return *this; - } - - private: - std::vector attrs_; -}; -} // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 0d7d5f6220d0..ae5f71d66269 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,6 +13,7 @@ struct OrtApi; namespace vaip_core { struct OrtApiForVaip { + onnxruntime::ProviderHost* host_; const OrtApi* ort_api_; // model Model* (*model_load)(const std::string& file); // [0] @@ -49,7 +50,7 @@ struct OrtApiForVaip { const std::string& description, const std::vector& input_args, const std::vector& output_args, - NodeAttributes& attributes, + const NodeAttributes& attributes, const std::string& domain); // [18] void (*graph_save)(const Graph& graph, const std::string& filename, const std::string& dat_filename, @@ -119,8 +120,8 @@ struct OrtApiForVaip { NodeAttributes* (*node_attributes_new)(); // [46] void (*node_attributes_delete)(NodeAttributes* p); // [47] void (*node_attributes_add)(NodeAttributes& p, AttributeProto&& attr); // [48] - AttributeProto* (*node_attributes_get)(NodeAttributes& p, - const std::string& name); // [49] + const AttributeProto* (*node_attributes_get)(const NodeAttributes& p, + const std::string& name); // [49] DllSafe> (*node_attributes_get_keys)( NodeAttributes& p); // [50] /// attr proto @@ -194,5 +195,4 @@ VAIP_DLL_SPEC const OrtApiForVaip* api(); ? ::vaip_core::api()->name \ : (assert(false && #name " is not set"), nullptr)) #endif -VAIP_DLL_SPEC void initialize_ort(); } // namespace vaip_core diff --git a/onnxruntime/core/providers/vitisai/symbols.def b/onnxruntime/core/providers/vitisai/symbols.def new file mode 100644 index 000000000000..4ec2f7914c20 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/symbols.def @@ -0,0 +1,2 @@ +EXPORTS + GetProvider diff --git a/onnxruntime/core/providers/vitisai/version_script.lds b/onnxruntime/core/providers/vitisai/version_script.lds new file mode 100644 index 000000000000..2c8e9c4b3ed6 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/version_script.lds @@ -0,0 +1,9 @@ +#_init and _fini should be local +VERS_1.0 { + global: + GetProvider; + + # Hide everything else. 
+ local: + *; +}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 5f20b32cd6dc..6fc09f3495aa 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -1,91 +1,34 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. -#include "core/graph/graph_utils.h" #include "vitisai_execution_provider.h" #include -#include #include #include -#include "core/common/common.h" - #include "vaip/capability.h" #include "vaip/global_api.h" -#include "core/session/custom_ops.h" -#include "core/session/inference_session.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { - constexpr const char* VITISAI = "VITISAI"; -static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { -#ifndef _WIN32 - auto model_path = graph_viewer.ModelPath().ToPathString(); -#else - using convert_t = std::codecvt_utf8; - std::wstring_convert strconverter; - auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); -#endif - return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); -} - -struct MyCustomOpKernel : OpKernel { - MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - op_kernel_ = - op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); - } - - ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } - - Status Compute(OpKernelContext* ctx) const override { - op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); - return Status::OK(); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); - - const OrtCustomOp& op_; - void* op_kernel_; -}; - -VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) +VitisAIExecutionProvider::VitisAIExecutionProvider( + const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { - custom_op_domains_ = initialize_vitisai_ep(); - registry_ = std::make_shared(); CreateKernelRegistry(); } void VitisAIExecutionProvider::CreateKernelRegistry() { - for (const auto& domain : custom_op_domains_) { + for (const auto& domain : get_domains_vitisaiep()) { for (const auto* op : domain->custom_ops_) { - KernelDefBuilder def_builder; - def_builder.SetName(op->GetName(op)); - def_builder.SetDomain(domain->domain_); - def_builder.SinceVersion(1); - if (op->version > 12) { - auto input_count = op->GetInputTypeCount(op); - for (auto i = 0u; i < input_count; i++) { - def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i); - } - } - def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, - std::unique_ptr& out) -> Status { - out = std::make_unique(info, *op); - return Status::OK(); - }; - std::ignore = registry_->Register(def_builder, kernel_create_fn); vitisai_optypes_.insert(op->GetName(op)); } } } -std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } +std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } std::vector> VitisAIExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { @@ -111,9 
+54,9 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; - const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); - assert(attr != nullptr); - size_t index = (size_t)attr->i(); + auto& attrs = fused_node_graph.fused_node.get().GetAttributes(); + assert(attrs.count("index")); + size_t index = attrs.at("index").i(); compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index e86b53339d4d..186427be4fab 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -9,8 +9,7 @@ #include #include -#include "core/framework/execution_provider.h" -#include "core/framework/customregistry.h" +#include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" // we cannot include vaip/vaip.hpp here because header file referred by @@ -21,7+20,6 @@ class DllSafe; class ExecutionProvider; } // namespace vaip_core namespace onnxruntime { - // Logical device representation. class VitisAIExecutionProvider : public IExecutionProvider { public: diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 4c416124ca8f..5895e1973f23 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -11,7 +11,6 @@ #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" -#include "core/providers/shared_library/provider_host_api.h" using namespace onnxruntime; namespace onnxruntime { @@ -30,10 +29,37 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr VitisAIProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - initialize_vitisai_ep(); - return std::make_shared(provider_options); -} +struct VitisAI_Provider : Provider { + // Takes a pointer to a provider specific structure to create the factory. For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure + std::shared_ptr + CreateExecutionProviderFactory(const void* options) override { + return std::make_shared(GetProviderOptions(options)); + } + // Convert provider options struct to ProviderOptions which is a map + ProviderOptions GetProviderOptions(const void* options) override { + auto vitisai_options = reinterpret_cast(options); + return *vitisai_options; + } + // Update provider options from key-value string configuration + void UpdateProviderOptions(void* options, const ProviderOptions& provider_options) override { + auto vitisai_options = reinterpret_cast(options); + for (const auto& entry : provider_options) { + vitisai_options->insert_or_assign(entry.first, entry.second); + } + }; + // Get provider specific custom op domain list. Provider has the responsibility to release OrtCustomOpDomain instances it creates.
+ void GetCustomOpDomainList(IExecutionProviderFactory*, std::vector&) override{}; + // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded + void Initialize() override { initialize_vitisai_ep(); } + // Called right before unloading the shared library + void Shutdown() override {} +} g_provider; } // namespace onnxruntime + +extern "C" { + +ORT_API(onnxruntime::Provider*, GetProvider) { + return &onnxruntime::g_provider; +} +} diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 91a7f0d930b5..dec8754ea244 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2724,6 +2724,7 @@ static constexpr OrtApi ort_api_1_to_18 = { &OrtApis::SetDeterministicCompute, &OrtApis::KernelContext_ParallelFor, &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, + &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c1caafa4dcad..9ce94ba89a94 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -509,4 +509,8 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2e445e4982d2..32ae15e71acc 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -57,6 +57,8 @@ namespace ONNX_NAMESPACE { // We use these names in the provider API because we don't have the protobuf definitions of the RepeatedField* types using int64s = google::protobuf::RepeatedField; +using float32s = google::protobuf::RepeatedField; +using StringStringEntryProtos = google::protobuf::RepeatedPtrField; using TensorProtos = google::protobuf::RepeatedPtrField; using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; using ValueInfoProtos = google::protobuf::RepeatedPtrField; @@ -77,6 +79,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; #include "core/providers/migraphx/migraphx_provider_factory_creator.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" +#include "core/providers/vitisai/vitisai_provider_factory_creator.h" #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cann/cann_provider_factory.h" @@ -123,6 +126,7 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); +ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { 
TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} @@ -274,7 +278,10 @@ struct ProviderHostImpl : ProviderHost { Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } - + Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) override { + return utils::UnpackInitializerData(tensor, model_path, unpacked_tensor); + } uint16_t math__floatToHalf(float f) override { return math::floatToHalf(f); } float math__halfToFloat(uint16_t h) override { return math::halfToFloat(h); } @@ -352,12 +359,32 @@ struct ProviderHostImpl : ProviderHost { void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; } std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } + // Env + Env& Env__Default() override { return Env::Default(); } + // Utils::DataTypeUtils (wrapped) const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) override { return ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(type_proto); } // int64s (wrapped) int int64s__size(const ONNX_NAMESPACE::int64s* p) override { return p->size(); } const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) override { return p->Get(index); } + void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) override { p->Reserve(size); }; + const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) override { return p->data(); } + + // float32s + void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) override { p->Reserve(size); }; + const float* float32s__data(const ONNX_NAMESPACE::float32s* p) override { return p->data(); } + int float32s__size(const ONNX_NAMESPACE::float32s* p) override { return p->size(); } + + // StringStringEntryProto + std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_key(); } + std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_value(); } + + // StringStringEntryProtos + void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) override { p->Clear(); }; + ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->Add(); } + int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); } + ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); }; #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional (wrapped) @@ -374,6 +401,7 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::TensorShapeProto& 
TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->shape(); } ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->mutable_shape(); } int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->elem_type(); } + void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) override { p->set_elem_type(value); }; // TypeProto_SparseTensor (wrapped) #if !defined(DISABLE_SPARSE_TENSORS) @@ -426,9 +454,18 @@ struct ProviderHostImpl : ProviderHost { float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->floats(i); } const std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->strings(i); } const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) override { return p->ints(); } + const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) override { return p->floats(); } + ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_ints(); } + ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_floats(); } + void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { p->add_ints(value); }; + void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float value) override { p->add_floats(value); }; + void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { p->add_strings(value); }; + int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) override { return p->i(); } float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) override { return p->f(); } + const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) override { return p->t(); } void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_s(value); } + void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) override { return p->set_f(value); } void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { return p->set_i(value); } const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) override { return p->s(); } void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); } @@ -450,6 +487,7 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_value_info(); } ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_initializer(); } ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) override { return p->add_node(); } + std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_name(); } ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); } void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; } @@ -467,6 +505,7 @@ struct ProviderHostImpl : ProviderHost { 
ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_graph(); } void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); } + ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); }; // NodeProto (wrapped) std::unique_ptr NodeProto__construct() override { return std::make_unique(); } @@ -481,19 +520,34 @@ struct ProviderHostImpl : ProviderHost { void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) override { delete p; } void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { *p = v; } bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_name(); } + void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) override { p->set_name(name); } + const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) override { return p->name(); } int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims_size(); } const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims(); } + void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) override { p->add_dims(value); } bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_location(); } int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_location(); } bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_raw_data(); } const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); } + std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); } + int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); } + void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) override { p->set_data_type(type); } bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); } void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); } + ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_external_data(); }; + void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_float_data(); } + void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int32_data(); } + void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_string_data(); } + void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int64_data(); } + void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_double_data(); } + void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_uint64_data(); } // TensorProtos (wrapped) ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); } + int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); } + ONNX_NAMESPACE::TensorProto& 
TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) override { return p->at(index); }; // TensorShapeProto_Dimension (wrapped) int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->value_case(); } @@ -503,6 +557,8 @@ struct ProviderHostImpl : ProviderHost { void TensorShapeProto_Dimension__set_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, int64_t value) override { return p->set_dim_value(value); } bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_value(); } bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_param(); } + const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const override { return p->denotation(); } + void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) override { return p->set_denotation(value); } // TensorShapeProto_Dimensions (wrapped) std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) override { @@ -531,6 +587,90 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } + static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { + auto* shape = ctx.getAttribute("shape"); + auto* data_type = ctx.getAttribute("data_type"); + int32_t elemType = 0; + if (data_type->s() == "float32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + } else if (data_type->s() == "int8") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT8; + } else if (data_type->s() == "uint8") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } else if (data_type->s() == "int32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; + } else if (data_type->s() == "int64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else if (data_type->s() == "int1") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; + } else if (data_type->s() == "bfloat16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; + } else if (data_type->s() == "float16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + } else if (data_type->s() == "uint16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; + } else if (data_type->s() == "int16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; + } else { + return; + } + ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); + if (shape != nullptr) { + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } else { + // set scalar type. 
+ ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); + } + } + + static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); + } + + static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { + auto num_inputs = ctx.getNumInputs(); + + // Run inferencing on the subgraph + auto* graphInferencer = ctx.getGraphAttributeInferencer("body"); + + std::vector input_data; + std::vector subgraph_input_types; + for (size_t i = 0; i < num_inputs; ++i) { + input_data.push_back(ctx.getInputData(i)); + subgraph_input_types.push_back(ctx.getInputType(i)); + } + + auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data); + for (size_t i = 0, end = output_types.size(); i < end; ++i) { + *ctx.getOutputType(i) = *output_types[i]; + } + } + void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override { + auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + const auto& domain_to_version_map = domain_instance.Map(); + if (domain_to_version_map.find(domain) == domain_to_version_map.end()) { + domain_instance.AddDomainToVersion(domain, 1, 1000); + } + auto schema = CreateSchema(domain, {op}); + switch (type) { + case 1: + schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + break; + case 2: + schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + break; + case 3: + schema.TypeAndShapeInferenceFunction(xir_shape_infer); + break; + default: + break; + } + ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION); + } + // ConfigOptions (wrapped) std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override { return p->GetConfigEntry(config_key); @@ -762,6 +902,9 @@ struct ProviderHostImpl : ProviderHost { void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) override { p->ToProto(proto, update_subgraphs); } const NodeAttributes& Node__GetAttributes(const Node* p) noexcept override { return p->GetAttributes(); } + void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) override { + p->AddAttribute(attr_name, value); + } size_t Node__GetInputEdgesCount(const Node* p) noexcept override { return p->GetInputEdgesCount(); } size_t Node__GetOutputEdgesCount(const Node* p) noexcept override { return p->GetOutputEdgesCount(); } @@ -770,13 +913,19 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputNodesBegin()); } std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputNodesEnd()); } - + std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept override { + return std::make_unique(p->InputEdgesBegin()); + } + std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept override { + return std::make_unique(p->InputEdgesEnd()); + } std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesBegin()); } std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesEnd()); } void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) override { p->ForEachDef(func, 
std::move(include_missing_optional_defs)); } const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) noexcept override { return p->GetAttributeNameToMutableSubgraphMap(); } std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const override { return p->GetAttributeNameToSubgraphMap(); } + int Node__NodeType(const Node* p) const noexcept override { return int(p->NodeType()); } // NodeArg (wrapped) const std::string& NodeArg__Name(const NodeArg* p) noexcept override { return p->Name(); } @@ -785,6 +934,7 @@ struct ProviderHostImpl : ProviderHost { const NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept override { return p->ToProto(); } bool NodeArg__Exists(const NodeArg* p) const noexcept override { return p->Exists(); } const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept override { return p->TypeAsProto(); } + Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) override { return p->OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); }; // NodeAttributes (wrapped) std::unique_ptr NodeAttributes__construct() override { return std::make_unique(); } @@ -807,12 +957,20 @@ struct ProviderHostImpl : ProviderHost { } void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) override { return p->insert(v.begin(), v.end()); } void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->emplace(k, v); } + void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->insert_or_assign(k, v); } void NodeAttributes__reserve(NodeAttributes* p, size_t size) override { p->reserve(size); } // Model (wrapped) + std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, + const logging::Logger& logger) override { + return std::make_unique(model_proto, model_path, nullptr, logger); + } void Model__operator_delete(Model* p) override { delete p; } Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); } std::unique_ptr Model__ToProto(Model* p) override { return std::make_unique(p->ToProto()); } + std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) override { return std::make_unique(p->ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold)); }; + const ModelMetaData& Model__MetaData(const Model* p) const noexcept override { return p->MetaData(); }; + Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) override { return Model::Load(file_path, model_proto); } // Graph (wrapped) std::unique_ptr Graph__CreateGraphViewer(const Graph* p) override { return std::make_unique(*p); } @@ -832,6 +990,12 @@ struct ProviderHostImpl : ProviderHost { void Graph__SetOutputs(Graph* p, gsl::span outputs) override { p->SetOutputs(outputs); } const std::vector& Graph__GetInputs(const Graph* p) noexcept override { return p->GetInputs(); } + std::vector Graph__Nodes(const Graph* p) override { + auto& node_refererence = p->Nodes(); + std::vector nodes(p->NumberOfNodes(), nullptr); + std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { 
return &n; }); + return nodes; + } bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) override { return p->GetInitializedTensor(tensor_name, value); } const Node* Graph__ParentNode(const Graph* p) const override { return p->ParentNode(); } @@ -841,6 +1005,40 @@ struct ProviderHostImpl : ProviderHost { const Path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); } bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); } + const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } + const Model& Graph__GetModel(const Graph* p) override { return p->GetModel(); } + void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const override { + p->ReverseDFSFrom(from, enter, leave, comp, stop); + } + Graph& Graph__SetGraphResolveNeeded(Graph* p) override { return p->SetGraphResolveNeeded(); } + void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) override { p->RemoveInitializedTensor(tensor_name); } + + std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const override { + return p->GetConsumerNodes(node_arg_name); + } + void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) override { + p->AddEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) override { + p->RemoveEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void Graph__RemoveNode(Graph* p, NodeIndex index) override { p->RemoveNode(index); } + Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) override { + return p->FuseSubGraph(sub_graph, fused_node_name); + } + void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) override { + p->UpdateProducerNode(node_arg_name, node_index); + } + const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const override { + return p->GetConstantInitializer(name, check_outer_scope); + } + const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) override { return p->GetAllInitializedTensors(); } int Graph__MaxNodeIndex(const Graph* p) const noexcept override { return p->MaxNodeIndex(); } Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept override { return p->GetNode(node_index); } const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); } @@ -885,11 +1083,14 @@ struct ProviderHostImpl : ProviderHost { void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); } + const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } 
// Path (wrapped) PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } const std::vector& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); } bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); } + std::unique_ptr Path__construct() override { return std::make_unique(); } + void Path__operator_delete(ONNX_NAMESPACE::Path* p) override { delete p; }; // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -1280,6 +1481,7 @@ static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_provi #endif ); static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION); +static ProviderLibrary s_library_vitisai(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_vitisai") LIBRARY_EXTENSION); static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION); static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION #ifndef _WIN32 @@ -1308,6 +1510,7 @@ static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_p void UnloadSharedProviders() { s_library_dnnl.Unload(); + s_library_vitisai.Unload(); s_library_openvino.Unload(); s_library_tensorrt.Unload(); s_library_cuda.Unload(); @@ -1524,6 +1727,10 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(co return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options); } +std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options); +} + ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() { return reinterpret_cast(s_library_openvino.Get().GetInfo()); } @@ -2416,3 +2623,34 @@ ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProvid ORT_UNUSED_PARAMETER(ptr); #endif } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + API_IMPL_BEGIN + onnxruntime::ProviderOptions provider_options; + for (size_t i = 0; i != num_keys; ++i) { + if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || + provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Provider options key/value cannot be empty"); + } + + // arbitrary length to validate the key/value. adjust if/when needed. + // TODO: are any other input validation checks required here (and in the other functions that process + // provider options)? 
+ if (strlen(provider_options_keys[i]) > 1024 || strlen(provider_options_values[i]) > 1024) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Maximum string length for a provider options key/value is 1024."); + } + + provider_options[provider_options_keys[i]] = provider_options_values[i]; + } + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); + } + + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 964355956b4a..ade1d96d617f 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -148,12 +148,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); -#endif - } else if (strcmp(provider_name, "VitisAI") == 0) { -#if defined(USE_VITISAI) - options->provider_factories.push_back(VitisAIProviderFactoryCreator::Create(provider_options)); -#else - status = create_not_supported_status(); #endif } else { ORT_UNUSED_PARAMETER(options); @@ -499,4 +493,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { ORT_UNUSED_PARAMETER(ptr); } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("VitisAI"); +} #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 8e13982ca686..9c36eb635ffc 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -982,7 +982,7 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kVitisAIExecutionProvider) { -#if USE_VITISAI +#ifdef USE_VITISAI const auto it = provider_options_map.find(type); if (it == provider_options_map.end()) { LOGS_DEFAULT(FATAL) << "cannot find provider options for VitisAIExecutionProvider"; diff --git a/setup.py b/setup.py index e94165fdf9b0..67d34b065ad0 100644 --- a/setup.py +++ b/setup.py @@ -298,6 +298,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_shared.so"]) libs.extend(["libonnxruntime_providers_dnnl.so"]) libs.extend(["libonnxruntime_providers_openvino.so"]) + libs.extend(["libonnxruntime_providers_vitisai.so"]) libs.append(providers_cuda_or_rocm) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) @@ -310,6 +311,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_dnnl.dylib"]) libs.extend(["libonnxruntime_providers_tensorrt.dylib"]) libs.extend(["libonnxruntime_providers_cuda.dylib"]) + libs.extend(["libonnxruntime_providers_vitisai.dylib"]) if nightly_build: 
libs.extend(["libonnxruntime_pywrapper.dylib"]) else: @@ -320,6 +322,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_tensorrt.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_vitisai.dll"]) # DirectML Libs libs.extend(["DirectML.dll"]) if nightly_build: