Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VitisAI] Cache node subgraph when necessary #22073

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_vitisai.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/include/vaip/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ struct ProviderHost {
bool include_outer_scope_args,
int execution_order) noexcept = 0;
virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0;
virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0;

// OpKernel
virtual const Node& OpKernel__Node(const OpKernel* p) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ class GraphViewer final {
g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order);
}
const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); }
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); }

GraphViewer() = delete;
GraphViewer(const GraphViewer&) = delete;
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
graph.SetInputs(inputs);
};
the_global_api.node_arg_external_location = vaip::node_arg_external_location;
the_global_api.model_to_proto = [](onnxruntime::Model& model) { return model.ToProto().release(); };
the_global_api.model_proto_serialize_as_string = [](ONNX_NAMESPACE::ModelProto& model_proto) {
return vaip_core::DllSafe(model_proto.SerializeAsString());
};
the_global_api.model_proto_delete = [](ONNX_NAMESPACE::ModelProto* p) { delete p; };
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/providers/vitisai/include/vaip/custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ class ExecutionProvider {
virtual DllSafe<std::vector<std::string>> get_meta_def_nodes() const = 0;
virtual DllSafe<std::vector<std::string>>
get_meta_def_constant_initializer() const = 0;
virtual bool get_meta_def_fallback_CPU() const { return false; };
virtual std::unique_ptr<CustomOp> compile() const = 0;

public:
inline void set_fused_node(const onnxruntime::Node* fused_node) {
fused_node_ = fused_node;
}
inline const onnxruntime::Node* get_fused_node() const {
return fused_node_;
}
inline void set_fused_node(const onnxruntime::Node* fused_node) { fused_node_ = fused_node; }
inline const onnxruntime::Node* get_fused_node() const { return fused_node_; }
inline void set_model(onnxruntime::Model* model) { model_ = model; }
inline onnxruntime::Model* get_model() const { return model_; }

private:
const onnxruntime::Node* fused_node_ = nullptr;
onnxruntime::Model* model_ = nullptr;
};

class CustomOp {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct NodeAttributes;
namespace ONNX_NAMESPACE {
struct AttributeProto;
struct TensorProto;
struct ModelProto;
#ifndef USE_VITISAI
enum TensorProto_DataType : int {
TensorProto_DataType_UNDEFINED = 0,
Expand Down Expand Up @@ -70,6 +71,7 @@ enum AttributeProto_AttributeType : int {
namespace vaip_core {
class GraphHolder;
using ONNX_NAMESPACE::AttributeProto;
using ONNX_NAMESPACE::ModelProto;
using ONNX_NAMESPACE::TensorProto;
using onnxruntime::Graph;
using onnxruntime::GraphViewer;
Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct OrtApi;

namespace vaip_core {

#define VAIP_ORT_API_MAJOR (9u)
#define VAIP_ORT_API_MAJOR (11u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
Expand Down Expand Up @@ -227,9 +227,12 @@ struct OrtApiForVaip {
const std::vector<int16_t>& data); // [89]
const std::filesystem::path& (*get_model_path)(const Graph& graph); // [90]
Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector<std::pair<std::string, int64_t>>& opset); //[91]
void (*graph_set_inputs)(Graph& graph,
gsl::span<const NodeArg* const> inputs); // [92]
int (*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93]
void (*graph_set_inputs)(Graph& graph, gsl::span<const NodeArg* const> inputs); // [92]
int (*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file,
size_t& offset, size_t& size, size_t& checksum); // [93]
ModelProto* (*model_to_proto)(Model& model); //[94]
DllSafe<std::string> (*model_proto_serialize_as_string)(ModelProto& model_proto); //[95]
void (*model_proto_delete)(ModelProto* p); // [96]
};

#ifndef USE_VITISAI
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,17 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndG
auto& attrs = fused_node_graph.fused_node.get().GetAttributes();
assert(attrs.count("index"));
size_t index = attrs.at("index").i();
(**this->execution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get());
auto& ep = (**this->execution_providers_)[index];
ep->set_fused_node(&fused_node_graph.fused_node.get());
if (ep->get_meta_def_fallback_CPU()) {
auto& subgraph = fused_node_graph.filtered_graph.get();
auto& logger = logging::LoggingManager::DefaultLogger();
auto model_proto = subgraph.CreateModel(logger)->ToProto();
subgraph.ToProto(*model_proto->mutable_graph(), true, true);
auto local_registries = IOnnxRuntimeOpSchemaRegistryList{subgraph.GetSchemaRegistry()};
auto model = Model::Create(std::move(*model_proto), subgraph.ModelPath(), &local_registries, logger);
ep->set_model(model.release());
}
compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) {
auto* p = (**this->execution_providers_)[index]->compile().release();
*state = p;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class VitisAIExecutionProvider : public IExecutionProvider {
ProviderOptions info_;
std::vector<OrtCustomOpDomain*> custom_op_domains_;
std::shared_ptr<KernelRegistry> registry_;
std::set<std::string> vitisai_optypes_;
// EP context related.
bool ep_ctx_enabled_ = false;
bool ep_ctx_embed_mode_ = true;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,7 @@ struct ProviderHostImpl : ProviderHost {
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast<ExecutionOrder>(execution_order));
}
const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }
IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const { return p->GetSchemaRegistry(); }

// OpKernel (direct)
const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); }
Expand Down