Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Refactor of TRT plugins support #17946

Merged
merged 25 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,12 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
}

// Copies this EP's custom op domains into `custom_op_domain_list`.
// The domain list is populated lazily: on the first call (while
// `info_.custom_op_domain_list` is still empty) the TRT plugins are pulled
// from the TRT plugin registry. `info_` is mutable, which is what lets this
// const accessor cache the result. A registration failure is logged as a
// warning only; the caller then simply receives an empty list.
void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
  const bool needs_population = info_.custom_op_domain_list.empty();
  if (needs_population) {
    const common::Status plugin_status = CreateTensorRTCustomOpDomainList(info_);
    if (!plugin_status.IsOK()) {
      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
    }
  }
  custom_op_domain_list = info_.custom_op_domain_list;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
Status ReplayGraph() override;

private:
TensorrtExecutionProviderInfo info_;
mutable TensorrtExecutionProviderInfo info_;
bool external_stream_ = false;
cudaStream_t stream_ = nullptr;
int max_partition_iterations_ = 1000;
Expand Down
10 changes: 0 additions & 10 deletions onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ struct Tensorrt_Provider : Provider {
info.device_id = device_id;
info.has_trt_options = false;

common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}

return std::make_shared<TensorrtProviderFactory>(info);
}

Expand Down Expand Up @@ -121,11 +116,6 @@ struct Tensorrt_Provider : Provider {
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;

common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}

return std::make_shared<TensorrtProviderFactory>(info);
}

Expand Down
30 changes: 28 additions & 2 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,35 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Create Custom Op if EP requests it
// Register Custom Op if EP requests it
std::vector<OrtCustomOpDomain*> custom_op_domains;
p_exec_provider->GetCustomOpDomainList(custom_op_domains);
std::vector<OrtCustomOpDomain*> candidate_custom_op_domains;
p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains);

auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type());

// Register the custom op domain only if it has not been registered before
if (registry_kernels.empty()) {
custom_op_domains = candidate_custom_op_domains;
} else {
for (auto candidate_custom_op_domain : candidate_custom_op_domains) {
for (auto registry_kernel : registry_kernels) {
const auto& kernel_map = registry_kernel->GetKernelCreateMap();
bool need_resigter = true;
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// If the kernel registry is the ep's custom op registry, we only need to check the first kernel,
// because all kernels in one kernel registry should have the same domain name.
for (auto iter = kernel_map.begin(); iter != kernel_map.end(); iter++) {
if (iter->second.kernel_def->Domain() == candidate_custom_op_domain->domain_) {
need_resigter = false;
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}
if (need_resigter) {
custom_op_domains.push_back(candidate_custom_op_domain);
}
}
}
}

if (!custom_op_domains.empty()) {
if (AddCustomOpDomains(custom_op_domains) != Status::OK()) {
Expand Down
43 changes: 25 additions & 18 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,28 @@

} // namespace onnxruntime

// Appends the TensorRT plugin custom op domains to the given session options.
// A domain whose name is already present in `options->custom_op_domains_` is
// skipped (with a warning) so that creating multiple sessions from the same
// OrtSessionOptions does not register the same domain twice.
//
// `extra_plugin_lib_paths` is forwarded to the TRT provider so additional
// plugin libraries can be loaded before the registry is queried.
void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) {
  // Capture-free lambda: it only reads its parameters, and they are
  // read-only, so take them by const reference.
  auto is_already_in_domains = [](const std::string& domain_name, const std::vector<OrtCustomOpDomain*>& domains) {
    for (auto* ptr : domains) {
      if (domain_name == ptr->domain_) {
        return true;
      }
    }
    return false;
  };

  std::vector<OrtCustomOpDomain*> custom_op_domains;
  onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
  provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);

  for (auto* ptr : custom_op_domains) {
    if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) {
      options->custom_op_domains_.push_back(ptr);
    } else {
      LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
    }
  }
}

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
API_IMPL_BEGIN
auto factory = onnxruntime::DnnlProviderFactoryCreator::Create(use_arena);
Expand All @@ -1644,13 +1666,8 @@

options->provider_factories.push_back(factory);

std::vector<OrtCustomOpDomain*> custom_op_domains;
std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);

return nullptr;
API_IMPL_END
Expand All @@ -1677,12 +1694,7 @@

options->provider_factories.push_back(factory);

std::vector<OrtCustomOpDomain*> custom_op_domains;
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, "");
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
AddTensorRTCustomOpDomainToSessionOption(options, "");

return nullptr;
API_IMPL_END
Expand Down Expand Up @@ -1786,13 +1798,8 @@

options->provider_factories.push_back(factory);

std::vector<OrtCustomOpDomain*> custom_op_domains;
std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths;
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);

return nullptr;
API_IMPL_END
Expand Down
15 changes: 14 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,15 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM*
#ifdef USE_TENSORRT
void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) {
auto is_already_in_domains = [&](std::string& domain_name, std::vector<OrtCustomOpDomain*>& domains) {
for (auto ptr : domains) {
if (domain_name == ptr->domain_) {
return true;
}
}
return false;
};

std::string trt_extra_plugin_lib_paths = "";
const auto it = options.find("trt_extra_plugin_lib_paths");
if (it != options.end()) {
Expand All @@ -441,7 +450,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
std::vector<OrtCustomOpDomain*> domain_list;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
for (auto ptr : domain_list) {
so.custom_op_domains_.push_back(ptr);
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
so.custom_op_domains_.push_back(ptr);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
}
}
} else {
ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,14 @@ def test_set_providers_with_options(self):
self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path))
self.assertEqual(option["trt_force_sequential_engine_build"], "1")

from onnxruntime.capi import _pybind_state as C

session_options = C.get_default_session_options()

        # TRT plugins registered as a custom op domain should only be added once to the session options,
        # regardless of the number of sessions created.
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
onnxrt.InferenceSession(get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"])
onnxrt.InferenceSession(get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"])

# We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability

"""
Expand Down
Loading