From 8b309c73153cc80a2c3a16b838ab802aa9f80df1 Mon Sep 17 00:00:00 2001 From: zhangli Date: Thu, 9 Feb 2023 11:55:26 +0800 Subject: [PATCH 01/16] wip --- csrc/mmdeploy/CMakeLists.txt | 4 + csrc/mmdeploy/apis/CMakeLists.txt | 1 + csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp | 22 +- csrc/mmdeploy/core/logger.h | 2 +- csrc/mmdeploy/core/model.h | 3 +- csrc/mmdeploy/graph/inference.cpp | 34 +- csrc/mmdeploy/graph/task.cpp | 1 + csrc/mmdeploy/preprocess/transform_module.cpp | 124 ++++-- csrc/mmdeploy/triton/CMakeLists.txt | 115 ++++++ csrc/mmdeploy/triton/convert.cpp | 255 ++++++++++++ csrc/mmdeploy/triton/convert.h | 17 + csrc/mmdeploy/triton/instance_state.cpp | 368 ++++++++++++++++++ csrc/mmdeploy/triton/instance_state.h | 43 ++ csrc/mmdeploy/triton/mmdeploy.cpp | 165 ++++++++ csrc/mmdeploy/triton/mmdeploy_utils.h | 48 +++ csrc/mmdeploy/triton/model_state.cpp | 90 +++++ csrc/mmdeploy/triton/model_state.h | 45 +++ demo/python/triton_client.py | 135 +++++++ 18 files changed, 1415 insertions(+), 57 deletions(-) create mode 100644 csrc/mmdeploy/triton/CMakeLists.txt create mode 100644 csrc/mmdeploy/triton/convert.cpp create mode 100644 csrc/mmdeploy/triton/convert.h create mode 100644 csrc/mmdeploy/triton/instance_state.cpp create mode 100644 csrc/mmdeploy/triton/instance_state.h create mode 100644 csrc/mmdeploy/triton/mmdeploy.cpp create mode 100644 csrc/mmdeploy/triton/mmdeploy_utils.h create mode 100644 csrc/mmdeploy/triton/model_state.cpp create mode 100644 csrc/mmdeploy/triton/model_state.h create mode 100644 demo/python/triton_client.py diff --git a/csrc/mmdeploy/CMakeLists.txt b/csrc/mmdeploy/CMakeLists.txt index 6bfbd3a95a..4d05e7ee79 100644 --- a/csrc/mmdeploy/CMakeLists.txt +++ b/csrc/mmdeploy/CMakeLists.txt @@ -18,4 +18,8 @@ if (MMDEPLOY_BUILD_SDK) add_subdirectory(net) add_subdirectory(codebase) add_subdirectory(apis) + + if (TRITON_MMDEPLOY_BACKEND) + add_subdirectory(triton) + endif () endif () diff --git a/csrc/mmdeploy/apis/CMakeLists.txt b/csrc/mmdeploy/apis/CMakeLists.txt index 1ab877be90..9331f710d8 100644 --- a/csrc/mmdeploy/apis/CMakeLists.txt +++ b/csrc/mmdeploy/apis/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory(java) if (MMDEPLOY_BUILD_SDK_PYTHON_API) add_subdirectory(python) endif () + diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp index e20ec6a224..806a420fd5 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp @@ -11,27 +11,28 @@ namespace mmdeploy { namespace cxx { -class Pipeline : public NonMovable { +class Pipeline : public UniqueHandle { public: Pipeline(const Value& config, const Context& context) { - mmdeploy_pipeline_t pipeline{}; - auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &pipeline); + auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &handle_); if (ec != MMDEPLOY_SUCCESS) { throw_exception(static_cast(ec)); } - pipeline_ = pipeline; } ~Pipeline() { - if (pipeline_) { - mmdeploy_pipeline_destroy(pipeline_); - pipeline_ = nullptr; + if (handle_) { + mmdeploy_pipeline_destroy(handle_); + handle_ = nullptr; } } + Pipeline(Pipeline&&) noexcept = default; + Pipeline& operator=(Pipeline&&) noexcept = default; + Value Apply(const Value& inputs) { mmdeploy_value_t tmp{}; - auto ec = mmdeploy_pipeline_apply(pipeline_, (mmdeploy_value_t)&inputs, &tmp); + auto ec = mmdeploy_pipeline_apply(handle_, (mmdeploy_value_t)&inputs, &tmp); if (ec != MMDEPLOY_SUCCESS) { throw_exception(static_cast(ec)); } @@ -50,7 +51,7 @@ class Pipeline : public NonMovable { if (ec != MMDEPLOY_SUCCESS) { throw_exception(static_cast(ec)); } - auto outputs = Apply(*reinterpret_cast(inputs)); + auto outputs = this->Apply(*reinterpret_cast(inputs)); mmdeploy_value_destroy(inputs); return outputs; @@ -65,9 +66,6 @@ class Pipeline : public NonMovable { } return rets; } - - private: - mmdeploy_pipeline_t pipeline_{}; }; } // namespace cxx diff --git a/csrc/mmdeploy/core/logger.h b/csrc/mmdeploy/core/logger.h index 73de4f0ee1..5d8947edf0 100644 --- a/csrc/mmdeploy/core/logger.h +++ b/csrc/mmdeploy/core/logger.h @@ -47,7 +47,7 @@ MMDEPLOY_API void SetLogger(spdlog::logger *logger); #endif #ifdef SPDLOG_LOGGER_CALL -#define MMDEPLOY_LOG(level, ...) SPDLOG_LOGGER_CALL(mmdeploy::GetLogger(), level, __VA_ARGS__) +#define MMDEPLOY_LOG(level, ...) SPDLOG_LOGGER_CALL(::mmdeploy::GetLogger(), level, __VA_ARGS__) #else #define MMDEPLOY_LOG(level, ...) mmdeploy::GetLogger()->log(level, __VA_ARGS__) #endif diff --git a/csrc/mmdeploy/core/model.h b/csrc/mmdeploy/core/model.h index fcb396d267..73a575f3d3 100644 --- a/csrc/mmdeploy/core/model.h +++ b/csrc/mmdeploy/core/model.h @@ -30,8 +30,9 @@ struct model_meta_info_t { struct deploy_meta_info_t { std::string version; + std::string task; std::vector models; - MMDEPLOY_ARCHIVE_MEMBERS(version, models); + MMDEPLOY_ARCHIVE_MEMBERS(version, task, models); }; class ModelImpl; diff --git a/csrc/mmdeploy/graph/inference.cpp b/csrc/mmdeploy/graph/inference.cpp index 8f5c8d1699..de9f632c60 100644 --- a/csrc/mmdeploy/graph/inference.cpp +++ b/csrc/mmdeploy/graph/inference.cpp @@ -14,24 +14,30 @@ using namespace framework; InferenceBuilder::InferenceBuilder(Value config) : Builder(std::move(config)) {} Result> InferenceBuilder::BuildImpl() { - auto& model_config = config_["params"]["model"]; - Model model; - if (model_config.is_any()) { - model = model_config.get(); - } else { - auto model_name = model_config.get(); - if (auto m = Maybe{config_} / "context" / "model" / model_name / identity{}) { - model = *m; + Value pipeline_config; + auto context = config_.value("context", Value(ValueType::kObject)); + const auto& params = config_["params"]; + if (params.contains("model")) { + auto& model_config = params["model"]; + Model model; + if (model_config.is_any()) { + model = model_config.get(); } else { - model = Model(model_name); + auto model_name = model_config.get(); + if (auto m = Maybe{config_} / "context" / "model" / model_name / identity{}) { + model = *m; + } else { + model = Model(model_name); + } } + OUTCOME_TRY(pipeline_config, model.ReadConfig("pipeline.json")); + context["model"] = std::move(model); + } else if (params.contains("pipeline")) { + assert(context.contains("model")); + auto model = context["model"].get(); + OUTCOME_TRY(pipeline_config, model.ReadConfig(params["pipeline"].get())); } - OUTCOME_TRY(auto pipeline_config, model.ReadConfig("pipeline.json")); - - auto context = config_.value("context", Value(ValueType::kObject)); - context["model"] = std::move(model); - if (context.contains("scope")) { auto name = config_.value("name", config_["type"].get()); auto scope = context["scope"].get_ref()->CreateScope(name); diff --git a/csrc/mmdeploy/graph/task.cpp b/csrc/mmdeploy/graph/task.cpp index 6cb6c4a798..6f876348d6 100644 --- a/csrc/mmdeploy/graph/task.cpp +++ b/csrc/mmdeploy/graph/task.cpp @@ -96,6 +96,7 @@ Result> TaskBuilder::BuildImpl() { task->is_thread_safe_ = config_.value("is_thread_safe", false); return std::move(task); } catch (const std::exception& e) { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); MMDEPLOY_ERROR("error parsing config: {}", config_); return nullptr; } diff --git a/csrc/mmdeploy/preprocess/transform_module.cpp b/csrc/mmdeploy/preprocess/transform_module.cpp index b718843ea8..269c170edf 100644 --- a/csrc/mmdeploy/preprocess/transform_module.cpp +++ b/csrc/mmdeploy/preprocess/transform_module.cpp @@ -10,46 +10,112 @@ namespace mmdeploy { class TransformModule { public: - ~TransformModule(); - TransformModule(TransformModule&&) noexcept; + ~TransformModule() = default; + TransformModule(TransformModule&&) noexcept = default; - explicit TransformModule(const Value& args); - Result operator()(const Value& input); + explicit TransformModule(const Value& args) { + const auto type = "Compose"; + auto creator = gRegistry().Get(type); + if (!creator) { + MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type, + gRegistry().List()); + throw_exception(eEntryNotFound); + } + auto cfg = args; + if (cfg.contains("device")) { + MMDEPLOY_WARN("force using device: {}", cfg["device"].get()); + auto device = Device(cfg["device"].get()); + cfg["context"]["device"] = device; + cfg["context"]["stream"] = Stream::GetDefault(device); + } + transform_ = creator->Create(cfg); + } + + Result operator()(const Value& input) { + auto data = input; + OUTCOME_TRY(transform_->Apply(data)); + return data; + } private: std::unique_ptr transform_; }; -TransformModule::~TransformModule() = default; +MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Transform, 0), [](const Value& config) { + return CreateTask(TransformModule{config}); +}); -TransformModule::TransformModule(TransformModule&&) noexcept = default; +#if 0 +class Preload { + public: + explicit Preload(const Value& args) { + const auto type = "Compose"; + auto creator = gRegistry().Get(type); + if (!creator) { + MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type, + gRegistry().List()); + throw_exception(eEntryNotFound); + } + auto cfg = args; + if (cfg.contains("device")) { + MMDEPLOY_WARN("force using device: {}", cfg["device"].get()); + auto device = Device(cfg["device"].get()); + cfg["context"]["device"] = device; + cfg["context"]["stream"] = Stream::GetDefault(device); + } + const auto& ctx = cfg["context"]; + ctx["device"].get_to(device_); + ctx["stream"].get_to(stream_); + } -TransformModule::TransformModule(const Value& args) { - const auto type = "Compose"; - auto creator = gRegistry().Get(type); - if (!creator) { - MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type, - gRegistry().List()); - throw_exception(eEntryNotFound); + Result operator()(const Value& input) { + auto data = input; + if (device_.is_device()) { + bool need_sync = false; + OUTCOME_TRY(Process(data, need_sync)); + MMDEPLOY_ERROR("need_sync = {}", need_sync); + MMDEPLOY_ERROR("{}", data); + if (need_sync) { + OUTCOME_TRY(stream_.Wait()); + } + } + return data; } - auto cfg = args; - if (cfg.contains("device")) { - MMDEPLOY_WARN("force using device: {}", cfg["device"].get()); - auto device = Device(cfg["device"].get()); - cfg["context"]["device"] = device; - cfg["context"]["stream"] = Stream::GetDefault(device); + + Result Process(Value& item, bool& need_sync) { + if (item.is_any()) { + auto& mat = item.get_ref(); + if (mat.device().is_host()) { + Mat tmp(mat.height(), mat.width(), mat.pixel_format(), mat.type(), device_); + OUTCOME_TRY(stream_.Copy(mat.buffer(), tmp.buffer(), mat.byte_size())); + mat = tmp; + need_sync |= true; + } + } else if (item.is_any()) { + auto& ten = item.get_ref(); + if (ten.device().is_host()) { + TensorDesc desc = ten.desc(); + desc.device = device_; + Tensor tmp(desc); + OUTCOME_TRY(stream_.Copy(ten.buffer(), tmp.buffer(), ten.byte_size())); + ten = tmp; + need_sync |= true; + } + } else if (item.is_array() || item.is_object()) { + for (auto& child : item) { + OUTCOME_TRY(Process(child, need_sync)); + } + } + return success(); } - transform_ = creator->Create(cfg); -} -Result TransformModule::operator()(const Value& input) { - auto data = input; - OUTCOME_TRY(transform_->Apply(data)); - return data; -} + private: + Device device_; + Stream stream_; +}; -MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Transform, 0), [](const Value& config) { - return CreateTask(TransformModule{config}); -}); +MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Preload, 0), + [](const Value& config) { return CreateTask(Preload{config}); }); +#endif } // namespace mmdeploy diff --git a/csrc/mmdeploy/triton/CMakeLists.txt b/csrc/mmdeploy/triton/CMakeLists.txt new file mode 100644 index 0000000000..f5b57eb4b3 --- /dev/null +++ b/csrc/mmdeploy/triton/CMakeLists.txt @@ -0,0 +1,115 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17) + +project(tritonmmdeploybackend LANGUAGES C CXX) + +# +# Options +# +# Must include options required for this project as well as any +# projects included in this one by FetchContent. +# +# GPU support is disabled by default because recommended backend +# doesn't use GPUs. +# + +set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") + + +# +# Dependencies +# +# FetchContent requires us to include the transitive closure of all +# repos that we depend on so that we can override the tags. +# +include(FetchContent) + +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend) + +add_library(triton-mmdeploy-backend SHARED + model_state.cpp + instance_state.cpp + convert.cpp + mmdeploy.cpp) + +target_include_directories(triton-mmdeploy-backend PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + +target_compile_options( + triton-mmdeploy-backend PRIVATE + $<$,$,$>: + -Wall -Wno-unused-parameter -Wno-type-limits -Werror> + $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc> +) + +target_link_libraries( + triton-mmdeploy-backend + PRIVATE + triton-core-serverapi # from repo-core + triton-core-backendapi # from repo-core + triton-core-serverstub # from repo-core + triton-backend-utils # from repo-backend +) + +target_link_libraries(triton-mmdeploy-backend PRIVATE mmdeploy) + +mmdeploy_export(triton-mmdeploy-backend) + + +if (WIN32) + set_target_properties( + triton-mmdeploy-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_mmdeploy + ) +else () + set_target_properties( + triton-mmdeploy-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_mmdeploy + ) +endif () diff --git a/csrc/mmdeploy/triton/convert.cpp b/csrc/mmdeploy/triton/convert.cpp new file mode 100644 index 0000000000..d1f44e0111 --- /dev/null +++ b/csrc/mmdeploy/triton/convert.cpp @@ -0,0 +1,255 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "convert.h" + +#include + +#include "mmdeploy/archive/value_archive.h" +#include "mmdeploy/codebase/mmaction/mmaction.h" +#include "mmdeploy/codebase/mmcls/mmcls.h" +#include "mmdeploy/codebase/mmdet/mmdet.h" +#include "mmdeploy/codebase/mmedit/mmedit.h" +#include "mmdeploy/codebase/mmocr/mmocr.h" +#include "mmdeploy/codebase/mmpose/mmpose.h" +#include "mmdeploy/codebase/mmrotate/mmrotate.h" +#include "mmdeploy/codebase/mmseg/mmseg.h" +#include "mmdeploy/core/utils/formatter.h" +#include "triton/backend/backend_common.h" + +namespace mmdeploy { + +namespace core = framework; + +core::Tensor Mat2Tensor(core::Mat mat) { + TensorDesc desc{mat.device(), mat.type(), {mat.height(), mat.width(), mat.channel()}, ""}; + return {desc, mat.buffer()}; +} + +} // namespace mmdeploy + +namespace triton::backend::mmdeploy { + +using Value = ::mmdeploy::Value; +using Tensor = ::mmdeploy::core::Tensor; +using TensorDesc = ::mmdeploy::core::TensorDesc; + +void ConvertClassifications(const Value& item, std::vector& tensors) { + ::mmdeploy::mmcls::Labels classify_outputs; + ::mmdeploy::from_value(item, classify_outputs); + Tensor labels(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(classify_outputs.size())}, + "labels"}); + Tensor scores(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(classify_outputs.size())}, + "scores"}); + auto labels_data = labels.data(); + auto scores_data = scores.data(); + for (const auto& c : classify_outputs) { + *labels_data++ = c.label_id; + *scores_data++ = c.score; + } + tensors.push_back(std::move(labels)); + tensors.push_back(std::move(scores)); +} + +void ConvertDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmdet::Detections detections; + ::mmdeploy::from_value(item, detections); + Tensor bboxes(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.size()), 5}, + "bboxes"}); + Tensor labels(TensorDesc{bboxes.device(), + ::mmdeploy::DataType::kINT32, + {static_cast(detections.size())}, + "labels"}); + auto bboxes_data = bboxes.data(); + auto labels_data = labels.data(); + for (const auto& det : detections) { + for (const auto& x : det.bbox) { + *bboxes_data++ = x; + } + *bboxes_data++ = det.score; + *labels_data++ = det.label_id; + } + tensors.push_back(std::move(bboxes)); + tensors.push_back(std::move(labels)); +} + +void ConvertSegmentation(const Value& item, std::vector& tensors) { + ::mmdeploy::mmseg::SegmentorOutput seg; + ::mmdeploy::from_value(item, seg); + if (seg.score.size()) { + auto desc = seg.score.desc(); + desc.name = "score"; + tensors.emplace_back(desc, seg.score.buffer()); + } + if (seg.mask.size()) { + auto desc = seg.mask.desc(); + desc.name = "mask"; + tensors.emplace_back(desc, seg.mask.buffer()); + } +} + +void ConvertMats(const Value& item, std::vector& tensors) { + ::mmdeploy::mmedit::RestorerOutput restoration; + ::mmdeploy::from_value(item, restoration); + tensors.push_back(::mmdeploy::Mat2Tensor(restoration)); +} + +void ConvertTextDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmocr::TextDetections detections; + ::mmdeploy::from_value(item, detections); + Tensor bboxes(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.size()), 9}, + "dets"}); + auto bboxes_data = bboxes.data(); + for (const auto& det : detections) { + bboxes_data = std::copy(det.bbox.begin(), det.bbox.end(), bboxes_data); + *bboxes_data++ = det.score; + } + tensors.push_back(std::move(bboxes)); +} + +void ConvertTextRecognitions(const Value& item, std::vector& tensors, + std::vector& strings) { + std::vector<::mmdeploy::mmocr::TextRecognition> recognitions; + ::mmdeploy::from_value(item, recognitions); + Tensor texts(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(recognitions.size())}, + "text"}); + Tensor score(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(recognitions.size())}, + "text_score"}); + auto text_data = texts.data(); + auto score_data = score.data(); + for (size_t text_id = 0; text_id < recognitions.size(); ++text_id) { + text_data[text_id] = static_cast(strings.size()); + strings.push_back(recognitions[text_id].text); + auto& s = recognitions[text_id].score; + if (!s.empty()) { + score_data[text_id] = std::accumulate(s.begin(), s.end(), 0.f) / static_cast(s.size()); + } else { + score_data[text_id] = 0; + } + } + tensors.push_back(std::move(texts)); + tensors.push_back(std::move(score)); +} + +void ConvertPreprocess(const Value& item, std::vector& tensors, + std::vector& strings) { + Value::Object img_metas; + for (auto it = item.begin(); it != item.end(); ++it) { + if (it->is_any()) { + auto tensor = it->get(); + auto desc = tensor.desc(); + desc.name = it.key(); + tensors.emplace_back(desc, tensor.buffer()); + } else if (!it->is_any<::mmdeploy::framework::Mat>()) { + img_metas.insert({it.key(), *it}); + } + } + auto index = static_cast(strings.size()); + strings.push_back(::mmdeploy::format_value(img_metas)); + Tensor img_meta_tensor( + TensorDesc{::mmdeploy::Device(0), ::mmdeploy::DataType::kINT32, {1}, "img_metas"}); + *img_meta_tensor.data() = index; + tensors.push_back(std::move(img_meta_tensor)); +} + +void ConvertPoseDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmpose::PoseDetectorOutput detections; + ::mmdeploy::from_value(item, detections); + Tensor pts(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.key_points.size()), 3}, + "keypoints"}); + auto pts_data = pts.data(); + for (const auto& p : detections.key_points) { + *pts_data++ = p.bbox[0]; + *pts_data++ = p.bbox[1]; + *pts_data++ = p.score; + } + tensors.push_back({std::move(pts)}); +} + +void ConvertRotatedDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmrotate::RotatedDetectorOutput detections; + ::mmdeploy::from_value(item, detections); + Tensor bboxes(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.detections.size()), 5}, + "bboxes"}); + Tensor labels(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(detections.detections.size())}, + "labels"}); + auto bboxes_data = bboxes.data(); + auto labels_data = labels.data(); + for (const auto& det : detections.detections) { + bboxes_data = std::copy(det.rbbox.begin(), det.rbbox.end(), bboxes_data); + *bboxes_data++ = det.score; + *labels_data++ = det.label_id; + } + tensors.push_back(std::move(bboxes)); + tensors.push_back(std::move(labels)); +} + +std::vector> ConvertOutputToTensors(const std::string& type, + int32_t request_count, const Value& output, + std::vector& strings) { + std::vector> tensors(request_count); + if (type == "Preprocess") { + for (int i = 0; i < request_count; ++i) { + ConvertPreprocess(output.front()[i], tensors[i], strings); + } + } else if (type == "Classifier") { + for (int i = 0; i < request_count; ++i) { + ConvertClassifications(output.front()[i], tensors[i]); + } + } else if (type == "Detector") { + for (int i = 0; i < request_count; ++i) { + ConvertDetections(output.front()[i], tensors[i]); + } + } else if (type == "Segmentor") { + for (int i = 0; i < request_count; ++i) { + ConvertSegmentation(output.front()[i], tensors[i]); + } + } else if (type == "Restorer") { + for (int i = 0; i < request_count; ++i) { + ConvertMats(output.front()[i], tensors[i]); + } + } else if (type == "TextDetector") { + for (int i = 0; i < request_count; ++i) { + ConvertTextDetections(output.front()[i], tensors[i]); + } + } else if (type == "TextRecognizer") { + for (int i = 0; i < request_count; ++i) { + ConvertTextRecognitions(output.front(), tensors[i], strings); + } + } else if (type == "PoseDetector") { + for (int i = 0; i < request_count; ++i) { + ConvertPoseDetections(output.front()[i], tensors[i]); + } + } else if (type == "RotatedDetector") { + for (int i = 0; i < request_count; ++i) { + ConvertRotatedDetections(output.front()[i], tensors[i]); + } + } else if (type == "TextOCR") { + for (int i = 0; i < request_count; ++i) { + ConvertTextDetections(output[0][i], tensors[i]); + ConvertTextRecognitions(output[1][i], tensors[i], strings); + } + } else { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, ("Unsupported type: " + type).c_str()); + } + return tensors; +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/convert.h b/csrc/mmdeploy/triton/convert.h new file mode 100644 index 0000000000..23e593efce --- /dev/null +++ b/csrc/mmdeploy/triton/convert.h @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_CONVERT_H +#define MMDEPLOY_CONVERT_H + +#include "mmdeploy/core/tensor.h" +#include "mmdeploy/core/value.h" + +namespace triton::backend::mmdeploy { + +std::vector> ConvertOutputToTensors( + const std::string& type, int32_t request_count, const ::mmdeploy::Value& output, + std::vector& strings); + +} + +#endif // MMDEPLOY_CONVERT_H diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp new file mode 100644 index 0000000000..3a4a026472 --- /dev/null +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -0,0 +1,368 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "instance_state.h" + +#include + +#include "convert.h" +#include "json.hpp" +#include "mmdeploy/archive/json_archive.h" +#include "mmdeploy/core/mat.h" +#include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy_utils.h" + +namespace triton::backend::mmdeploy { + +TRITONSERVER_Error* ModelInstanceState::Create(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) { + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} +ModelInstanceState::ModelInstanceState(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state), + pipeline_(model_state_->CreatePipeline(Kind(), DeviceId())) {} + +// TRITON DIR MMDeploy +// (Tensor, PixFmt, Region) -> (Mat , Region) +// [Tensor] <- ([Tensor], Meta ) +// [Tensor] -> ([Tensor], Meta ) +// [Tensor] <- [Value] + +TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests, + uint32_t request_count) { + // Collect various timestamps during the execution of this batch or + // requests. These values are reported below before returning from + // the function. + + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + ModelState* model_state = StateForModel(); + + // 'responses' is initialized as a parallel array to 'requests', + // with one TRITONBACKEND_Response object for each + // TRITONBACKEND_Request object. If something goes wrong while + // creating these response objects, the backend simply returns an + // error from TRITONBACKEND_ModelInstanceExecute, indicating to + // Triton that this backend did not create or send any responses and + // so it is up to Triton to create and send an appropriate error + // response for each request. RETURN_IF_ERROR is one of several + // useful macros for error handling that can be found in + // backend_common.h. + + std::vector responses; + responses.reserve(request_count); + for (uint32_t r = 0; r < request_count; ++r) { + TRITONBACKEND_Request* request = requests[r]; + TRITONBACKEND_Response* response; + RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, request)); + responses.push_back(response); + } + + BackendInputCollector collector(requests, request_count, &responses, + model_state->TritonMemoryManager(), false /* pinned_enabled */, + nullptr /* stream*/); + + // To instruct ProcessTensor to "gather" the entire batch of input + // tensors into a single contiguous buffer in CPU memory, set the + // "allowed input types" to be the CPU ones (see tritonserver.h in + // the triton-inference-server/core repo for allowed memory types). + std::vector> allowed_input_types = { + {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; + + std::vector input_buffers(model_state->input_names().size()); + + std::vector> collectors(request_count); + std::vector> response_vecs(request_count); + + ::mmdeploy::Value::Array image_and_metas_array; + ::mmdeploy::Value::Array input_tensors_array; + + // Setting input data + for (uint32_t request_index = 0; request_index < request_count; ++request_index) { + ::mmdeploy::Value::Object input_tensors; + ::mmdeploy::Value::Object image_and_metas; + response_vecs[request_index] = {responses[request_index]}; + collectors[request_index] = std::make_unique( + &requests[request_index], 1, &response_vecs[request_index], + model_state->TritonMemoryManager(), false, nullptr); + + for (size_t input_id = 0; input_id < model_state->input_names().size(); ++input_id) { + const auto& input_name = model_state->input_names()[input_id]; + // Get input shape + TRITONBACKEND_Input* input{}; + RETURN_IF_ERROR( + TRITONBACKEND_RequestInput(requests[request_index], input_name.c_str(), &input)); + TRITONSERVER_DataType data_type{}; + const int64_t* dims{}; + uint32_t dims_count{}; + RETURN_IF_ERROR(TRITONBACKEND_InputProperties(input, nullptr, &data_type, &dims, &dims_count, + nullptr, nullptr)); + if (data_type != TRITONSERVER_TYPE_BYTES) { + // Collect input buffer + const char* buffer{}; + size_t buffer_size{}; + TRITONSERVER_MemoryType memory_type{}; + int64_t memory_type_id{}; + RETURN_IF_ERROR(collectors[request_index]->ProcessTensor( + input_name.c_str(), nullptr, 0, allowed_input_types, &buffer, &buffer_size, + &memory_type, &memory_type_id)); + + ::mmdeploy::framework::Device device(0); + if (memory_type == TRITONSERVER_MEMORY_GPU) { + device = ::mmdeploy::framework::Device("cuda", static_cast(memory_type_id)); + } + if (model_state->input_formats()[request_index] == "FORMAT_NHWC") { + // Construct Mat from shape & buffer + ::mmdeploy::framework::Mat mat( + static_cast(dims[0]), static_cast(dims[1]), ::mmdeploy::PixelFormat::kBGR, + ::mmdeploy::DataType::kINT8, + std::shared_ptr(const_cast(buffer), [](auto) {}), device); + image_and_metas.insert({input_name, mat}); + } else { + ::mmdeploy::framework::Tensor tensor( + ::mmdeploy::framework::TensorDesc{ + device, ::mmdeploy::DataType::kFLOAT, + ::mmdeploy::framework::TensorShape(dims, dims + dims_count), input_name}, + std::shared_ptr(const_cast(buffer), [](auto) {})); + input_tensors.insert({input_name, std::move(tensor)}); + } + } else { + ::mmdeploy::Value value; + GetStringInputTensor(input, dims, dims_count, value); + assert(value.is_array()); + ::mmdeploy::update(image_and_metas, value.front().object(), 2); + } + } + + if (!input_tensors.empty()) { + input_tensors_array.emplace_back(std::move(input_tensors)); + } + if (!image_and_metas.empty()) { + image_and_metas_array.emplace_back(std::move(image_and_metas)); + } + + // Input from device memory is not supported yet + const bool need_cuda_input_sync = collectors[request_index]->Finalize(); + if (need_cuda_input_sync) { +#if TRITON_ENABLE_GPU + cudaStreamSynchronize(CudaStream()); +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + "mmdeploy backend: unexpected CUDA sync required by collector"); +#endif + } + } + + ::mmdeploy::Value input_args; + if (!image_and_metas_array.empty()) { + input_args.push_back(std::move(image_and_metas_array)); + } + if (!input_tensors_array.empty()) { + input_args.push_back(std::move(input_tensors_array)); + } + + MMDEPLOY_ERROR("input: {}", input_args); + + uint64_t compute_start_ns = 0; + SET_TIMESTAMP(compute_start_ns); + + ::mmdeploy::Value outputs = pipeline_.Apply(input_args); + + std::vector strings; + auto output_tensors = + ConvertOutputToTensors(model_state->task_type(), request_count, outputs, strings); + + uint64_t compute_end_ns = 0; + SET_TIMESTAMP(compute_end_ns); + + std::vector> responders(request_count); + MMDEPLOY_ERROR("request_count {}", request_count); + for (uint32_t request_index = 0; request_index < request_count; ++request_index) { + responders[request_index] = std::make_unique( + &requests[request_index], 1, &response_vecs[request_index], + model_state->TritonMemoryManager(), false, false, nullptr); + for (const auto& name : model_state->output_names()) { + MMDEPLOY_ERROR("name {}", name); + } + for (size_t output_id = 0; output_id < model_state->output_names().size(); ++output_id) { + auto output_name = model_state->output_names()[output_id]; + MMDEPLOY_ERROR("output name {}", output_name); + auto output_data_type = model_state->output_data_types()[output_id]; + for (const auto& tensor : output_tensors[request_index]) { + if (tensor.name() == output_name) { + if (output_data_type != TRITONSERVER_TYPE_BYTES) { + auto shape = tensor.shape(); + MMDEPLOY_ERROR("name {}, shape {}", tensor.name(), shape); + auto memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + if (not tensor.device().is_host()) { + memory_type = TRITONSERVER_MEMORY_GPU; + memory_type_id = tensor.device().device_id(); + } + responders[request_index]->ProcessTensor( + tensor.name(), ConvertDataType(tensor.data_type()), shape, tensor.data(), + memory_type, memory_type_id); + } else { + RETURN_IF_ERROR(SetStringOutputTensor(tensor, strings, responses[request_index])); + } + break; + } + } + } + + const bool need_cuda_output_sync = responders[request_index]->Finalize(); + if (need_cuda_output_sync) { +#if TRITON_ENABLE_GPU + cudaStreamSynchronize(CudaStream()); +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + "mmdeploy backend: unexpected CUDA sync required by responder"); +#endif + } + } + + // Send all the responses that haven't already been sent because of + // an earlier error. + for (auto& response : responses) { + if (response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), + "failed to send response"); + } + } + + uint64_t exec_end_ns = 0; + SET_TIMESTAMP(exec_end_ns); + +#ifdef TRITON_ENABLE_STATS + // For batch statistics need to know the total batch size of the + // requests. This is not necessarily just the number of requests, + // because if the model supports batching then any request can be a + // batched request itself. + size_t total_batch_size = request_count; +#else + (void)exec_start_ns; + (void)exec_end_ns; + (void)compute_start_ns; + (void)compute_end_ns; +#endif // TRITON_ENABLE_STATS + + // Report statistics for each request, and then release the request. + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + +#ifdef TRITON_ENABLE_STATS + LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics( + TritonModelInstance(), request, (responses[r] != nullptr) /* success */, + exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting request statistics"); +#endif // TRITON_ENABLE_STATS + + LOG_IF_ERROR(TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + +#ifdef TRITON_ENABLE_STATS + // Report batch statistics. + LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics( + TritonModelInstance(), total_batch_size, exec_start_ns, compute_start_ns, + compute_end_ns, exec_end_ns), + "failed reporting batch request statistics"); +#endif // TRITON_ENABLE_STATS + + return nullptr; // success +} + +TRITONSERVER_Error* ModelInstanceState::GetStringInputTensor(TRITONBACKEND_Input* input, + const int64_t* dims, + uint32_t dims_count, + ::mmdeploy::Value& value) { + ::mmdeploy::Value::Array array; + const char* buffer{}; + uint64_t buffer_byte_size{}; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id{}; + RETURN_IF_ERROR(TRITONBACKEND_InputBuffer(input, 0, reinterpret_cast(&buffer), + &buffer_byte_size, &memory_type, &memory_type_id)); + auto count = std::accumulate(dims, dims + dims_count, 1LL, std::multiplies<>{}); + size_t offset = 0; + for (int64_t i = 0; i < count; ++i) { + // read string length + if (offset + sizeof(uint32_t) > buffer_byte_size) { + break; + } + auto length = *reinterpret_cast(buffer + offset); + offset += sizeof(uint32_t); + // read string data + if (offset + length > buffer_byte_size) { + break; + } + std::string data(buffer + offset, buffer + offset + length); + offset += length; + // deserialize from json string + auto data_value = ::mmdeploy::from_json<::mmdeploy::Value>(nlohmann::json::parse(data)); + array.push_back(std::move(data_value)); + } + value = std::move(array); + return nullptr; +} + +TRITONSERVER_Error* ModelInstanceState::SetStringOutputTensor( + const ::mmdeploy::framework::Tensor& tensor, const std::vector& strings, + TRITONBACKEND_Response* response) { + assert(tensor.data_type() == ::mmdeploy::DataType::kINT32); + TRITONSERVER_Error* err{}; + TRITONBACKEND_Output* response_output{}; + err = TRITONBACKEND_ResponseOutput(response, &response_output, tensor.name(), + TRITONSERVER_TYPE_BYTES, tensor.shape().data(), + tensor.shape().size()); + if (!err) { + size_t data_byte_size{}; + auto index_data = tensor.data(); + auto size = tensor.size(); + for (int64_t j = 0; j < size; ++j) { + data_byte_size += strings[index_data[j]].size(); + } + auto expected_byte_size = data_byte_size + sizeof(uint32_t) * size; + void* buffer{}; + TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU; + int64_t actual_memory_type_id = 0; + err = TRITONBACKEND_OutputBuffer(response_output, &buffer, expected_byte_size, + &actual_memory_type, &actual_memory_type_id); + if (!err) { + bool cuda_used = false; + size_t copied_byte_size = 0; + for (int64_t j = 0; j < size; ++j) { + auto len = static_cast(strings[index_data[j]].size()); + err = CopyBuffer(tensor.name(), TRITONSERVER_MEMORY_CPU, 0, actual_memory_type, + actual_memory_type_id, sizeof(uint32_t), &len, + static_cast(buffer) + copied_byte_size, nullptr, &cuda_used); + if (err) { + break; + } + copied_byte_size += sizeof(uint32_t); + err = CopyBuffer(tensor.name(), TRITONSERVER_MEMORY_CPU, 0, actual_memory_type, + actual_memory_type_id, len, strings[index_data[j]].data(), + static_cast(buffer) + copied_byte_size, nullptr, &cuda_used); + if (err) { + break; + } + copied_byte_size += len; + } + } + } + return err; +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/instance_state.h b/csrc/mmdeploy/triton/instance_state.h new file mode 100644 index 0000000000..bb0d9e71b6 --- /dev/null +++ b/csrc/mmdeploy/triton/instance_state.h @@ -0,0 +1,43 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_INSTANCE_STATE_H +#define MMDEPLOY_INSTANCE_STATE_H + +#include "mmdeploy/core/tensor.h" +#include "model_state.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" + +namespace triton::backend::mmdeploy { + +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + ~ModelInstanceState() override = default; + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + TRITONSERVER_Error* Execute(TRITONBACKEND_Request** requests, uint32_t request_count); + + TRITONSERVER_Error* GetStringInputTensor(TRITONBACKEND_Input* input, const int64_t* dims, + uint32_t dims_count, ::mmdeploy::Value& value); + + TRITONSERVER_Error* SetStringOutputTensor(const ::mmdeploy::framework::Tensor& tensor, + const std::vector& strings, + TRITONBACKEND_Response* response); + + private: + ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance); + + private: + ModelState* model_state_; + ::mmdeploy::Pipeline pipeline_; +}; + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_INSTANCE_STATE_H diff --git a/csrc/mmdeploy/triton/mmdeploy.cpp b/csrc/mmdeploy/triton/mmdeploy.cpp new file mode 100644 index 0000000000..b4901aaf14 --- /dev/null +++ b/csrc/mmdeploy/triton/mmdeploy.cpp @@ -0,0 +1,165 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "instance_state.h" +#include "mmdeploy/core/logger.h" +#include "model_state.h" +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/core/tritonbackend.h" + +namespace triton::backend::mmdeploy { + +extern "C" { + +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) { + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); + std::string name(cname); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); + + // Check the backend API version that Triton supports vs. what this + // backend was compiled against. Make sure that the Triton major + // version is the same and the minor version is >= what this backend + // uses. + uint32_t api_version_major, api_version_minor; + RETURN_IF_ERROR(TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + (std::string("Triton TRITONBACKEND API version: ") + + std::to_string(api_version_major) + "." + std::to_string(api_version_minor)) + .c_str()); + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("'") + name + "' TRITONBACKEND API version: " + + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) + .c_str()); + + if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) || + (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, + "triton backend API version does not support this backend"); + } + + // The backend configuration may contain information needed by the + // backend, such as tritonserver command-line arguments. This + // backend doesn't use any such configuration but for this example + // print whatever is available. + TRITONSERVER_Message* backend_config_message; + RETURN_IF_ERROR(TRITONBACKEND_BackendConfig(backend, &backend_config_message)); + + const char* buffer; + size_t byte_size; + RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(backend_config_message, &buffer, &byte_size)); + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("backend configuration:\n") + buffer).c_str()); + + // This backend does not require any "global" state but as an + // example create a string to demonstrate. + std::string* state = new std::string("backend state"); + RETURN_IF_ERROR(TRITONBACKEND_BackendSetState(backend, reinterpret_cast(state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_Finalize when a backend is no longer +// needed. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) { + // Delete the "global" state associated with the backend. + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + std::string* state = reinterpret_cast(vstate); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_Finalize: state is '") + *state + "'").c_str()); + + delete state; + + return nullptr; // success +} + +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) { + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + return nullptr; // success +} + +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + auto model_state = reinterpret_cast(vstate); + delete model_state; + + return nullptr; // success +} +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInstanceInitialize when a model +// instance is created to allow the backend to initialize any state +// associated with the instance. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize( + TRITONBACKEND_ModelInstance* instance) { + // Get the model state associated with this instance's model. + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + // Create a ModelInstanceState object and associate it with the + // TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR(ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceSetState(instance, reinterpret_cast(instance_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelInstanceFinalize when a model +// instance is no longer needed. The backend should cleanup any state +// associated with the model instance. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize( + TRITONBACKEND_ModelInstance* instance) { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = reinterpret_cast(vstate); + delete instance_state; + + return nullptr; // success +} + +} // extern "C" + +extern "C" { + +// When Triton calls TRITONBACKEND_ModelInstanceExecute it is required +// that a backend create a response for each request in the batch. A +// response may be the output tensors required for that request or may +// be an error that is returned in the response. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) { + // Triton will not call this function simultaneously for the same + // 'instance'. But since this backend could be used by multiple + // instances from multiple models the implementation needs to handle + // multiple calls to this function at the same time (with different + // 'instance' objects). Best practice for a high-performance + // implementation is to avoid introducing mutex/lock and instead use + // only function-local and model-instance-specific state. + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast(&instance_state))); + return instance_state->Execute(requests, request_count); +} + +} // extern "C" + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/mmdeploy_utils.h b/csrc/mmdeploy/triton/mmdeploy_utils.h new file mode 100644 index 0000000000..2fad294ad7 --- /dev/null +++ b/csrc/mmdeploy/triton/mmdeploy_utils.h @@ -0,0 +1,48 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_MMDEPLOY_UTILS_H +#define MMDEPLOY_MMDEPLOY_UTILS_H + +#include "mmdeploy/core/types.h" + +namespace triton::backend::mmdeploy { + +inline TRITONSERVER_DataType ConvertDataType(::mmdeploy::DataType data_type) { + using namespace ::mmdeploy::data_types; + switch (data_type) { + case kFLOAT: + return TRITONSERVER_TYPE_FP32; + case kHALF: + return TRITONSERVER_TYPE_FP16; + case kINT8: + return TRITONSERVER_TYPE_UINT8; + case kINT32: + return TRITONSERVER_TYPE_INT32; + case kINT64: + return TRITONSERVER_TYPE_INT64; + default: + return TRITONSERVER_TYPE_INVALID; + } +} + +inline ::mmdeploy::DataType ConvertDataType(TRITONSERVER_DataType data_type) { + using namespace ::mmdeploy::data_types; + switch (data_type) { + case TRITONSERVER_TYPE_FP32: + return kFLOAT; + case TRITONSERVER_TYPE_FP16: + return kHALF; + case TRITONSERVER_TYPE_UINT8: + return kINT8; + case TRITONSERVER_TYPE_INT32: + return kINT32; + case TRITONSERVER_TYPE_INT64: + return kINT64; + default: + return ::mmdeploy::DataType::kCOUNT; + } +} + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_MMDEPLOY_UTILS_H diff --git a/csrc/mmdeploy/triton/model_state.cpp b/csrc/mmdeploy/triton/model_state.cpp new file mode 100644 index 0000000000..ae8ae8eb4e --- /dev/null +++ b/csrc/mmdeploy/triton/model_state.cpp @@ -0,0 +1,90 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "model_state.h" + +#include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy/pipeline.hpp" + +namespace triton::backend::mmdeploy { + +TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) { + try { + *state = new ModelState(triton_model); + } catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return {}; +} + +ModelState::ModelState(TRITONBACKEND_Model* triton_model) + : BackendModel(triton_model), model_(JoinPath({RepositoryPath(), std::to_string(Version())})) { + THROW_IF_BACKEND_MODEL_ERROR(ValidateModelConfig()); +} + +TRITONSERVER_Error* ModelState::ValidateModelConfig() { + common::TritonJson::Value inputs; + common::TritonJson::Value outputs; + RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs)); + RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs)); + + for (size_t i = 0; i < inputs.ArraySize(); ++i) { + common::TritonJson::Value input; + RETURN_IF_ERROR(inputs.IndexAsObject(i, &input)); + + triton::common::TritonJson::Value reshape; + RETURN_ERROR_IF_TRUE(input.Find("reshape", &reshape), TRITONSERVER_ERROR_UNSUPPORTED, + std::string("reshape not supported for input tensor")); + + std::string name; + RETURN_IF_ERROR(input.MemberAsString("name", &name)); + input_names_.push_back(name); + + std::string data_type; + RETURN_IF_ERROR(input.MemberAsString("data_type", &data_type)); + input_data_types_.push_back(ModelConfigDataTypeToTritonServerDataType(data_type)); + + std::string format; + RETURN_IF_ERROR(input.MemberAsString("format", &format)); + input_formats_.push_back(format); + } + + for (size_t i = 0; i < outputs.ArraySize(); ++i) { + common::TritonJson::Value output; + RETURN_IF_ERROR(outputs.IndexAsObject(i, &output)); + + triton::common::TritonJson::Value reshape; + RETURN_ERROR_IF_TRUE(output.Find("reshape", &reshape), TRITONSERVER_ERROR_UNSUPPORTED, + std::string("reshape not supported for output tensor")); + + std::string name; + RETURN_IF_ERROR(output.MemberAsString("name", &name)); + output_names_.push_back(name); + + std::string data_type; + RETURN_IF_ERROR(output.MemberAsString("data_type", &data_type)); + output_data_types_.push_back(ModelConfigDataTypeToTritonServerDataType(data_type)); + } + + return {}; +} + +::mmdeploy::Pipeline ModelState::CreatePipeline(TRITONSERVER_InstanceGroupKind kind, + int device_id) { + // infer device name + std::string device_name = "cpu"; + if (kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + device_name = "cuda"; + } + + auto config = model_.ReadConfig("pipeline.json").value(); + + config["context"]["model"] = model_; + + ::mmdeploy::Context context(::mmdeploy::Device(device_name, device_id)); + return {config, context}; +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/model_state.h b/csrc/mmdeploy/triton/model_state.h new file mode 100644 index 0000000000..f0ef5fedbe --- /dev/null +++ b/csrc/mmdeploy/triton/model_state.h @@ -0,0 +1,45 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_MODEL_STATE_H +#define MMDEPLOY_MODEL_STATE_H + +#define MMDEPLOY_CXX_USE_OPENCV 0 + +#include "mmdeploy/core/model.h" +#include "mmdeploy/pipeline.hpp" +#include "triton/backend/backend_model.h" + +namespace triton::backend::mmdeploy { + +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, ModelState** state); + + const std::vector& input_names() const { return input_names_; } + const std::vector& output_names() const { return output_names_; } + const std::vector& input_data_types() const { return input_data_types_; } + const std::vector& output_data_types() const { return output_data_types_; } + + const std::vector& input_formats() const { return input_formats_; } + + ::mmdeploy::Pipeline CreatePipeline(TRITONSERVER_InstanceGroupKind kind, int device_id); + + const std::string& task_type() { return model_.meta().task; } + + private: + explicit ModelState(TRITONBACKEND_Model* triton_model); + + TRITONSERVER_Error* ValidateModelConfig(); + + private: + ::mmdeploy::framework::Model model_; + std::vector input_names_; + std::vector output_names_; + std::vector input_data_types_; + std::vector output_data_types_; + std::vector input_formats_; +}; + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_MODEL_STATE_H diff --git a/demo/python/triton_client.py b/demo/python/triton_client.py new file mode 100644 index 0000000000..c3009787de --- /dev/null +++ b/demo/python/triton_client.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from math import cos, sin + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient + + +def parse_args(): + parser = argparse.ArgumentParser(description='') + parser.add_argument('model_name', type=str) + parser.add_argument('image_path', type=str) + parser.add_argument( + '-v', '--model_version', type=str, required=False, default='1') + parser.add_argument( + '-u', '--url', type=str, required=False, default='localhost:8001') + return parser.parse_args() + + +def get_palette(num_classes=256): + state = np.random.get_state() + # random color + np.random.seed(42) + palette = np.random.randint(0, 256, size=(num_classes, 3)) + np.random.set_state(state) + return [tuple(c) for c in palette] + + +def vis_cls(img, scores, label_ids): + print('\n'.join(map(str, zip(scores, label_ids)))) + + +def vis_det(img, bboxes, labels): + for bbox, label in zip(bboxes, labels): + (left, top, right, bottom), score = bbox[0:4].astype(int), bbox[4] + if score < 0.3: + continue + cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0)) + return img + + +def vis_rdet(img, bboxes, labels): + for rbbox, label_id in zip(bboxes, labels): + [cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1] + if score < 0.1: + continue + [wx, wy, hx, hy] = \ + 0.5 * np.array([w, w, -h, h]) * \ + np.array([cos(angle), sin(angle), sin(angle), cos(angle)]) + points = np.array([[[int(cx - wx - hx), + int(cy - wy - hy)], + [int(cx + wx - hx), + int(cy + wy - hy)], + [int(cx + wx + hx), + int(cy + wy + hy)], + [int(cx - wx + hx), + int(cy - wy + hy)]]]) + cv2.drawContours(img, points, -1, (0, 255, 0), 2) + return img + + +def vis_seg(img, mask, scores): + if mask is None: + mask = np.argmax(scores, axis=0) + + palette = get_palette() + color_seg = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[mask == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * 0.5 + color_seg * 0.5 + return img.astype(np.uint8) + + +def vis_ocr(img, dets, text, text_score): + pts = ((dets[:, 0:8] + 0.5).reshape(len(dets), -1, 2).astype(int)) + cv2.polylines(img, pts, True, (0, 255, 0), 2) + print('\n'.join(map(str, zip(range(len(text)), text, text_score)))) + return img + + +def vis_pose(img, dets, kpts): + pass + + +def main(): + args = parse_args() + triton_client = grpcclient.InferenceServerClient(url=args.url) + + model_config = triton_client.get_model_config( + model_name=args.model_name, model_version=args.model_version) + + img = cv2.imread(args.image_path) + + if img is None: + print(f'failed to load image {args.image_path}') + return + + task = model_config.config.parameters['task'].string_value + + task_map = dict( + Classifier=(('scores', 'labels'), vis_cls), + Detector=(('bboxes', 'labels'), vis_det), + TextOCR=(('dets', 'text', 'text_score'), vis_ocr), + Restorer=(('output',), lambda _, hires: hires), + Segmentor=(('mask', 'score'), vis_seg), + RotatedDetector=(('bboxes', 'labels'), None), + DetPose=(('bboxes', 'keypoints'), vis_pose)) + + output_names, visualize = task_map[task] + + # request input + inputs = [grpcclient.InferInput('ori_img', img.shape, 'UINT8')] + inputs[0].set_data_from_numpy(img) + + # request outputs + outputs = map(grpcclient.InferRequestedOutput, output_names) + + # run inference + response = triton_client.infer( + model_config.config.name, inputs, outputs=list(outputs)) + + # visualize results + vis = visualize(img, *map(response.as_numpy, output_names)) + + if vis is not None: + cv2.imshow('', vis) + cv2.waitKey(0) + + +if __name__ == '__main__': + main() From 7f85148ec53946534d848dad1f1a1982569cfefc Mon Sep 17 00:00:00 2001 From: zhangli Date: Tue, 7 Mar 2023 12:16:59 +0800 Subject: [PATCH 02/16] wip --- csrc/mmdeploy/triton/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/mmdeploy/triton/CMakeLists.txt b/csrc/mmdeploy/triton/CMakeLists.txt index f5b57eb4b3..3924ae8b42 100644 --- a/csrc/mmdeploy/triton/CMakeLists.txt +++ b/csrc/mmdeploy/triton/CMakeLists.txt @@ -95,7 +95,8 @@ target_link_libraries( triton-backend-utils # from repo-backend ) -target_link_libraries(triton-mmdeploy-backend PRIVATE mmdeploy) +mmdeploy_load_static(triton-mmdeploy-backend MMDeployStaticModules) +target_link_libraries(triton-mmdeploy-backend PRIVATE MMDeployLibs) mmdeploy_export(triton-mmdeploy-backend) From c89eb58d62c5dc81abf6a7318eebf74a825ac955 Mon Sep 17 00:00:00 2001 From: zhangli Date: Thu, 9 Mar 2023 18:36:06 +0800 Subject: [PATCH 03/16] wip --- demo/python/to_triton_model.py | 147 +++++++++++++++++++++++++++++++++ demo/python/triton_client.py | 2 +- 2 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 demo/python/to_triton_model.py diff --git a/demo/python/to_triton_model.py b/demo/python/to_triton_model.py new file mode 100644 index 0000000000..cc77082838 --- /dev/null +++ b/demo/python/to_triton_model.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json +import os.path as osp + +import tritonclient.grpc.model_config_pb2 as pb +from google.protobuf import text_format + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_path', type=str) + parser.add_argument('--name', type=str) + parser.add_argument('--nocopy', type=bool) + return parser.parse_args() + + +IMAGE_INPUT = [ + dict( + name='ori_img', + dtype=pb.TYPE_UINT8, + format=pb.ModelInput.FORMAT_NHWC, + dims=[-1, -1, 3]), + dict(name='pix_fmt', dtype=pb.TYPE_INT32, dims=[1], optional=True) +] + +TASK_OUTPUT = dict( + Preprocess=[ + dict(name='img', dtype=pb.TYPE_FP32, dims=[3, -1, -1]), + dict(name='img_metas', dtype=pb.TYPE_STRING, dims=[1]) + ], + Classifier=[ + dict(name='scores', dtype=pb.TYPE_FP32, dims=[-1, 1]), + dict(name='label_ids', dtype=pb.TYPE_FP32, dims=[-1, 1]) + ], + Detector=[dict(name='dets', dtype=pb.TYPE_FP32, dims=[-1, 1])], + Segmentor=[ + dict(name='mask', dtype=pb.TYPE_INT32, dims=[-1, -1]), + dict(name='score', dtype=pb.TYPE_FP32, dims=[-1, -1, -1]) + ], + Restorer=[ + dict(name='output', dtype=pb.TYPE_FP32, dims=[-1, -1, 3]) + ], + TextDetector=[], + TextRecognizer=[], + PoseDetector=[], + RotatedDetector=[], + TextOCR=[], + DetPose=[]) + + +def add_input(model_config, params): + p = model_config.input.add() + p.name = params['name'] + p.data_type = params['dtype'] + p.dims.extend(params['dims']) + if 'format' in params: + p.format = params['format'] + if 'optional' in params: + p.optional = params['optional'] + + +def add_output(model_config, params): + p = model_config.output.add() + p.name = params['name'] + p.data_type = params['dtype'] + p.dims.extend(params['dims']) + + +def serialize_model_config(model_config): + return text_format.MessageToString( + model_config, + use_short_repeated_primitives=True, + use_index_order=True, + print_unknown_fields=True) + + +def create_model_config(name, task, backend=None, platform=None): + model_config = pb.ModelConfig() + if backend: + model_config.backend = backend + if platform: + model_config.platform = platform + model_config.name = name + model_config.max_batch_size = 0 + + for input in IMAGE_INPUT: + add_input(model_config, input) + for output in TASK_OUTPUT[task]: + add_output(model_config, output) + return model_config + + +def create_preprocess_model(): + pass + + +def get_onnx_io_names(detail_info): + onnx_config = detail_info['onnx_config'] + return onnx_config['input_names'], onnx_config['output_names'] + + +def create_inference_model(deploy_info, pipeline_info, detail_info): + if 'pipeline' in pipeline_info: + # old-style pipeline specification + pipeline = pipeline_info['pipeline']['tasks'] + else: + pipeline = pipeline_info['tasks'] + + for task_cfg in pipeline: + if task_cfg['module'] == 'Net': + input_names, output_names = get_onnx_io_names(detail_info) + + +def create_postprocess_model(): + pass + + +def create_pipeline_model(): + pass + + +def create_ensemble_model(deploy_cfg, pipeline_cfg): + inference_model_config = create_inference_model(deploy_cfg, pipeline_cfg) + preprocess_model_config = create_preprocess_model() + postprocess_model_config = create_postprocess_model() + pipeline_model_config = create_pipeline_model() + + +def main(): + args = parse_args() + model_path = args.model_path + if not osp.isdir(model_path): + model_path = osp.split(model_path)[-2] + if osp.isdir(model_path): + with open(osp.join(model_path, 'deploy.json'), 'r') as f: + deploy_cfg = json.load(f) + with open(osp.join(model_path, 'pipeline.json'), 'r') as f: + pipeline_cfg = json.load(f) + task = deploy_cfg['task'] + model_config = create_model_config('model', task, 'onnxruntime') + data = serialize_model_config(model_config) + print(data) + + +if __name__ == '__main__': + main() diff --git a/demo/python/triton_client.py b/demo/python/triton_client.py index c3009787de..0dee948d91 100644 --- a/demo/python/triton_client.py +++ b/demo/python/triton_client.py @@ -105,7 +105,7 @@ def main(): Classifier=(('scores', 'labels'), vis_cls), Detector=(('bboxes', 'labels'), vis_det), TextOCR=(('dets', 'text', 'text_score'), vis_ocr), - Restorer=(('output',), lambda _, hires: hires), + Restorer=(('output', ), lambda _, hires: hires), Segmentor=(('mask', 'score'), vis_seg), RotatedDetector=(('bboxes', 'labels'), None), DetPose=(('bboxes', 'keypoints'), vis_pose)) From 1183a9815e0a32974b5447ae01d594010e7b4d44 Mon Sep 17 00:00:00 2001 From: zhangli Date: Tue, 4 Apr 2023 11:29:40 +0000 Subject: [PATCH 04/16] update --- csrc/mmdeploy/net/trt/trt_net.cpp | 1 + csrc/mmdeploy/triton/CMakeLists.txt | 5 ++--- csrc/mmdeploy/triton/convert.cpp | 1 + csrc/mmdeploy/triton/instance_state.cpp | 11 ++++------- demo/python/triton_client.py | 9 +++++++-- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/csrc/mmdeploy/net/trt/trt_net.cpp b/csrc/mmdeploy/net/trt/trt_net.cpp index 8b9b98b5d7..142fca7fe9 100644 --- a/csrc/mmdeploy/net/trt/trt_net.cpp +++ b/csrc/mmdeploy/net/trt/trt_net.cpp @@ -8,6 +8,7 @@ #include "mmdeploy/core/model.h" #include "mmdeploy/core/module.h" #include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy/device/cuda/cuda_device.h" namespace mmdeploy::framework { diff --git a/csrc/mmdeploy/triton/CMakeLists.txt b/csrc/mmdeploy/triton/CMakeLists.txt index 3924ae8b42..efdf880d7a 100644 --- a/csrc/mmdeploy/triton/CMakeLists.txt +++ b/csrc/mmdeploy/triton/CMakeLists.txt @@ -61,13 +61,11 @@ FetchContent_Declare( repo-core GIT_REPOSITORY https://github.com/triton-inference-server/core.git GIT_TAG ${TRITON_CORE_REPO_TAG} - GIT_SHALLOW ON ) FetchContent_Declare( repo-backend GIT_REPOSITORY https://github.com/triton-inference-server/backend.git GIT_TAG ${TRITON_BACKEND_REPO_TAG} - GIT_SHALLOW ON ) FetchContent_MakeAvailable(repo-common repo-core repo-backend) @@ -98,8 +96,9 @@ target_link_libraries( mmdeploy_load_static(triton-mmdeploy-backend MMDeployStaticModules) target_link_libraries(triton-mmdeploy-backend PRIVATE MMDeployLibs) -mmdeploy_export(triton-mmdeploy-backend) +set_target_properties(triton-mmdeploy-backend PROPERTIES INSTALL_RPATH "\$ORIGIN") +install(TARGETS triton-mmdeploy-backend DESTINATION backend/mmdeploy) if (WIN32) set_target_properties( diff --git a/csrc/mmdeploy/triton/convert.cpp b/csrc/mmdeploy/triton/convert.cpp index d1f44e0111..42e7de6c36 100644 --- a/csrc/mmdeploy/triton/convert.cpp +++ b/csrc/mmdeploy/triton/convert.cpp @@ -90,6 +90,7 @@ void ConvertSegmentation(const Value& item, std::vector& tensors) { auto desc = seg.mask.desc(); desc.name = "mask"; tensors.emplace_back(desc, seg.mask.buffer()); + tensors.back().Squeeze(); } } diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp index 3a4a026472..ad997993e3 100644 --- a/csrc/mmdeploy/triton/instance_state.cpp +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -172,7 +172,7 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests input_args.push_back(std::move(input_tensors_array)); } - MMDEPLOY_ERROR("input: {}", input_args); + MMDEPLOY_DEBUG("input: {}", input_args); uint64_t compute_start_ns = 0; SET_TIMESTAMP(compute_start_ns); @@ -187,23 +187,20 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests SET_TIMESTAMP(compute_end_ns); std::vector> responders(request_count); - MMDEPLOY_ERROR("request_count {}", request_count); + MMDEPLOY_DEBUG("request_count {}", request_count); for (uint32_t request_index = 0; request_index < request_count; ++request_index) { responders[request_index] = std::make_unique( &requests[request_index], 1, &response_vecs[request_index], model_state->TritonMemoryManager(), false, false, nullptr); - for (const auto& name : model_state->output_names()) { - MMDEPLOY_ERROR("name {}", name); - } for (size_t output_id = 0; output_id < model_state->output_names().size(); ++output_id) { auto output_name = model_state->output_names()[output_id]; - MMDEPLOY_ERROR("output name {}", output_name); + MMDEPLOY_DEBUG("output name {}", output_name); auto output_data_type = model_state->output_data_types()[output_id]; for (const auto& tensor : output_tensors[request_index]) { if (tensor.name() == output_name) { if (output_data_type != TRITONSERVER_TYPE_BYTES) { auto shape = tensor.shape(); - MMDEPLOY_ERROR("name {}, shape {}", tensor.name(), shape); + MMDEPLOY_DEBUG("name {}, shape {}", tensor.name(), shape); auto memory_type = TRITONSERVER_MEMORY_CPU; int64_t memory_type_id = 0; if (not tensor.device().is_host()) { diff --git a/demo/python/triton_client.py b/demo/python/triton_client.py index 0dee948d91..98f3e38d82 100644 --- a/demo/python/triton_client.py +++ b/demo/python/triton_client.py @@ -61,6 +61,10 @@ def vis_rdet(img, bboxes, labels): def vis_seg(img, mask, scores): + if mask is not None: + print(f'mask {mask.shape}') + if scores is not None: + print(f'scores {scores.shape}') if mask is None: mask = np.argmax(scores, axis=0) @@ -127,8 +131,9 @@ def main(): vis = visualize(img, *map(response.as_numpy, output_names)) if vis is not None: - cv2.imshow('', vis) - cv2.waitKey(0) + cv2.imwrite('vis.jpg', vis) + # cv2.imshow('', vis) + # cv2.waitKey(0) if __name__ == '__main__': From d644d27c6651ae25936787920869f4f29b5fa196 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 26 Apr 2023 17:40:22 +0800 Subject: [PATCH 05/16] remove no used backend state --- csrc/mmdeploy/triton/mmdeploy.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/csrc/mmdeploy/triton/mmdeploy.cpp b/csrc/mmdeploy/triton/mmdeploy.cpp index b4901aaf14..4d1f6ca3f5 100644 --- a/csrc/mmdeploy/triton/mmdeploy.cpp +++ b/csrc/mmdeploy/triton/mmdeploy.cpp @@ -52,11 +52,6 @@ MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backe RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(backend_config_message, &buffer, &byte_size)); LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("backend configuration:\n") + buffer).c_str()); - // This backend does not require any "global" state but as an - // example create a string to demonstrate. - std::string* state = new std::string("backend state"); - RETURN_IF_ERROR(TRITONBACKEND_BackendSetState(backend, reinterpret_cast(state))); - return nullptr; // success } @@ -64,16 +59,6 @@ MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backe // needed. // MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) { - // Delete the "global" state associated with the backend. - void* vstate; - RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); - std::string* state = reinterpret_cast(vstate); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, - (std::string("TRITONBACKEND_Finalize: state is '") + *state + "'").c_str()); - - delete state; - return nullptr; // success } From a8b78de918443652b6b654b44e348a26d6e0697d Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 26 Apr 2023 17:59:05 +0800 Subject: [PATCH 06/16] fix build --- csrc/mmdeploy/apis/CMakeLists.txt | 1 - csrc/mmdeploy/device/cuda/cuda_device.h | 5 +++++ csrc/mmdeploy/net/trt/trt_net.cpp | 1 - demo/python/to_triton_model.py | 4 +--- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/mmdeploy/apis/CMakeLists.txt b/csrc/mmdeploy/apis/CMakeLists.txt index 9331f710d8..1ab877be90 100644 --- a/csrc/mmdeploy/apis/CMakeLists.txt +++ b/csrc/mmdeploy/apis/CMakeLists.txt @@ -9,4 +9,3 @@ add_subdirectory(java) if (MMDEPLOY_BUILD_SDK_PYTHON_API) add_subdirectory(python) endif () - diff --git a/csrc/mmdeploy/device/cuda/cuda_device.h b/csrc/mmdeploy/device/cuda/cuda_device.h index 20b894652d..55d9b439f0 100644 --- a/csrc/mmdeploy/device/cuda/cuda_device.h +++ b/csrc/mmdeploy/device/cuda/cuda_device.h @@ -1,5 +1,8 @@ // Copyright (c) OpenMMLab. All rights reserved. +#ifndef MMDEPLOY_SRC_DEVICE_CUDA_CUDE_DEVICE_H_ +#define MMDEPLOY_SRC_DEVICE_CUDA_CUDE_DEVICE_H_ + #include #include @@ -196,3 +199,5 @@ class CudaDeviceGuard { }; } // namespace mmdeploy::framework + +#endif // MMDEPLOY_SRC_DEVICE_CUDA_CUDE_DEVICE_H_ diff --git a/csrc/mmdeploy/net/trt/trt_net.cpp b/csrc/mmdeploy/net/trt/trt_net.cpp index 142fca7fe9..8b9b98b5d7 100644 --- a/csrc/mmdeploy/net/trt/trt_net.cpp +++ b/csrc/mmdeploy/net/trt/trt_net.cpp @@ -8,7 +8,6 @@ #include "mmdeploy/core/model.h" #include "mmdeploy/core/module.h" #include "mmdeploy/core/utils/formatter.h" -#include "mmdeploy/device/cuda/cuda_device.h" namespace mmdeploy::framework { diff --git a/demo/python/to_triton_model.py b/demo/python/to_triton_model.py index cc77082838..92d768870f 100644 --- a/demo/python/to_triton_model.py +++ b/demo/python/to_triton_model.py @@ -38,9 +38,7 @@ def parse_args(): dict(name='mask', dtype=pb.TYPE_INT32, dims=[-1, -1]), dict(name='score', dtype=pb.TYPE_FP32, dims=[-1, -1, -1]) ], - Restorer=[ - dict(name='output', dtype=pb.TYPE_FP32, dims=[-1, -1, 3]) - ], + Restorer=[dict(name='output', dtype=pb.TYPE_FP32, dims=[-1, -1, 3])], TextDetector=[], TextRecognizer=[], PoseDetector=[], From 645dee1bf9e85752268182e050ad43b1d3c848f1 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 27 Apr 2023 19:27:24 +0800 Subject: [PATCH 07/16] backwards compatibility --- csrc/mmdeploy/triton/instance_state.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp index ad997993e3..9e65f4eb6a 100644 --- a/csrc/mmdeploy/triton/instance_state.cpp +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -190,8 +190,8 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests MMDEPLOY_DEBUG("request_count {}", request_count); for (uint32_t request_index = 0; request_index < request_count; ++request_index) { responders[request_index] = std::make_unique( - &requests[request_index], 1, &response_vecs[request_index], - model_state->TritonMemoryManager(), false, false, nullptr); + &requests[request_index], 1, &response_vecs[request_index], false, + model_state->TritonMemoryManager(), false, nullptr); for (size_t output_id = 0; output_id < model_state->output_names().size(); ++output_id) { auto output_name = model_state->output_names()[output_id]; MMDEPLOY_DEBUG("output name {}", output_name); From 4cd226fe36b8caa20f7b1a0bde4b8872f14b40da Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 4 May 2023 19:39:42 +0800 Subject: [PATCH 08/16] pipeline sync --- csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp | 3 --- csrc/mmdeploy/triton/instance_state.cpp | 12 +++++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp b/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp index 27dc6578b5..59df608b59 100644 --- a/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp +++ b/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp @@ -64,9 +64,6 @@ static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); } Result RTMDetSepBNHead::GetBBoxes(const Value& prep_res, const std::vector& bbox_preds, const std::vector& cls_scores) const { - MMDEPLOY_DEBUG("bbox_pred: {}, {}", bbox_preds[0].shape(), dets[0].data_type()); - MMDEPLOY_DEBUG("cls_score: {}, {}", scores[0].shape(), scores[0].data_type()); - std::vector filter_boxes; std::vector obj_probs; std::vector class_ids; diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp index 9e65f4eb6a..4315c23cdc 100644 --- a/csrc/mmdeploy/triton/instance_state.cpp +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -7,6 +7,7 @@ #include "convert.h" #include "json.hpp" #include "mmdeploy/archive/json_archive.h" +#include "mmdeploy/core/device.h" #include "mmdeploy/core/mat.h" #include "mmdeploy/core/utils/formatter.h" #include "mmdeploy_utils.h" @@ -178,6 +179,15 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests SET_TIMESTAMP(compute_start_ns); ::mmdeploy::Value outputs = pipeline_.Apply(input_args); + { + std::string device_name = "cpu"; + if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + device_name = "cuda"; + } + auto device = ::mmdeploy::framework::Device(device_name.c_str(), DeviceId()); + auto stream = ::mmdeploy::framework::Stream::GetDefault(device); + stream.Wait(); + } std::vector strings; auto output_tensors = @@ -276,7 +286,7 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests TritonModelInstance(), total_batch_size, exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), "failed reporting batch request statistics"); -#endif // TRITON_ENABLE_STATS +#endif // TRITON_ENABLE_STATS return nullptr; // success } From 0208fbad9d8404850a9125a4addc909769ccc1e4 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 15 May 2023 16:31:20 +0800 Subject: [PATCH 09/16] add demo --- csrc/mmdeploy/codebase/mmseg/segment.cpp | 1 + csrc/mmdeploy/triton/convert.cpp | 157 ++++++++++--- csrc/mmdeploy/triton/convert.h | 6 +- csrc/mmdeploy/triton/instance_state.cpp | 212 +++++++++++++----- csrc/mmdeploy/triton/instance_state.h | 1 + csrc/mmdeploy/triton/json_input.cpp | 14 ++ csrc/mmdeploy/triton/json_input.h | 22 ++ csrc/mmdeploy/triton/model_state.cpp | 47 +++- csrc/mmdeploy/triton/model_state.h | 3 +- demo/python/to_triton_model.py | 145 ------------ demo/python/triton_client.py | 140 ------------ demo/triton/image-classification/README.md | 1 + demo/triton/image-classification/README_CN.md | 0 .../image-classification/grpc_client.py | 76 +++++++ .../serving/model/1/README.md | 1 + .../serving/model/config.pbtxt | 20 ++ demo/triton/instance-segmentation/README.md | 1 + .../instance-segmentation/README_zh-CN.md | 0 .../instance-segmentation/grpc_client.py | 94 ++++++++ .../serving/model/1/README.md | 1 + .../serving/model/config.pbtxt | 30 +++ demo/triton/keypoint-detection/README.md | 4 + demo/triton/keypoint-detection/README_CN.md | 0 demo/triton/keypoint-detection/grpc_client.py | 92 ++++++++ .../serving/model/1/README.md | 1 + .../serving/model/config.pbtxt | 29 +++ demo/triton/object-detection/README.md | 1 + demo/triton/object-detection/README_zh-CN.md | 0 demo/triton/object-detection/grpc_client.py | 79 +++++++ .../serving/model/1/README.md | 1 + .../serving/model/config.pbtxt | 20 ++ .../oriented-object-detection/README.md | 1 + .../oriented-object-detection/README_zh-CN.md | 0 .../oriented-object-detection/grpc_client.py | 90 ++++++++ .../serving/model/1/README.md | 1 + .../serving/model/config.pbtxt | 20 ++ demo/triton/semantic-segmentation/README.md | 1 + .../triton/semantic-segmentation/README_CN.md | 0 .../semantic-segmentation/grpc_client.py | 95 ++++++++ .../serving/mask/model/1/README.md | 1 + .../serving/mask/model/config.pbtxt | 15 ++ .../serving/score/model/1/README.md | 1 + .../serving/score/model/config.pbtxt | 15 ++ demo/triton/text-detection/README.md | 6 + demo/triton/text-detection/README_CN.md | 0 demo/triton/text-detection/grpc_client.py | 82 +++++++ .../text-detection/serving/model/1/README.md | 1 + .../text-detection/serving/model/config.pbtxt | 21 ++ demo/triton/text-ocr/grpc_client.py | 80 +++++++ .../serving/model/1/pipeline_template.json | 50 +++++ .../serving/model/1/text_detection/README.md | 0 .../model/1/text_recognition/README.md | 0 .../text-ocr/serving/model/config.pbtxt | 33 +++ demo/triton/text-recognition/README.md | 1 + demo/triton/text-recognition/README_CN.md | 0 demo/triton/text-recognition/grpc_client.py | 90 ++++++++ .../serving/model/1/README.md | 1 + .../serving/model/config.pbtxt | 28 +++ demo/triton/to_triton_model.py | 178 +++++++++++++++ 59 files changed, 1623 insertions(+), 387 deletions(-) create mode 100644 csrc/mmdeploy/triton/json_input.cpp create mode 100644 csrc/mmdeploy/triton/json_input.h delete mode 100644 demo/python/to_triton_model.py delete mode 100644 demo/python/triton_client.py create mode 100644 demo/triton/image-classification/README.md create mode 100644 demo/triton/image-classification/README_CN.md create mode 100644 demo/triton/image-classification/grpc_client.py create mode 100644 demo/triton/image-classification/serving/model/1/README.md create mode 100644 demo/triton/image-classification/serving/model/config.pbtxt create mode 100644 demo/triton/instance-segmentation/README.md create mode 100644 demo/triton/instance-segmentation/README_zh-CN.md create mode 100644 demo/triton/instance-segmentation/grpc_client.py create mode 100644 demo/triton/instance-segmentation/serving/model/1/README.md create mode 100644 demo/triton/instance-segmentation/serving/model/config.pbtxt create mode 100644 demo/triton/keypoint-detection/README.md create mode 100644 demo/triton/keypoint-detection/README_CN.md create mode 100644 demo/triton/keypoint-detection/grpc_client.py create mode 100644 demo/triton/keypoint-detection/serving/model/1/README.md create mode 100644 demo/triton/keypoint-detection/serving/model/config.pbtxt create mode 100644 demo/triton/object-detection/README.md create mode 100644 demo/triton/object-detection/README_zh-CN.md create mode 100644 demo/triton/object-detection/grpc_client.py create mode 100644 demo/triton/object-detection/serving/model/1/README.md create mode 100644 demo/triton/object-detection/serving/model/config.pbtxt create mode 100644 demo/triton/oriented-object-detection/README.md create mode 100644 demo/triton/oriented-object-detection/README_zh-CN.md create mode 100644 demo/triton/oriented-object-detection/grpc_client.py create mode 100644 demo/triton/oriented-object-detection/serving/model/1/README.md create mode 100644 demo/triton/oriented-object-detection/serving/model/config.pbtxt create mode 100644 demo/triton/semantic-segmentation/README.md create mode 100644 demo/triton/semantic-segmentation/README_CN.md create mode 100644 demo/triton/semantic-segmentation/grpc_client.py create mode 100644 demo/triton/semantic-segmentation/serving/mask/model/1/README.md create mode 100644 demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt create mode 100644 demo/triton/semantic-segmentation/serving/score/model/1/README.md create mode 100644 demo/triton/semantic-segmentation/serving/score/model/config.pbtxt create mode 100644 demo/triton/text-detection/README.md create mode 100644 demo/triton/text-detection/README_CN.md create mode 100644 demo/triton/text-detection/grpc_client.py create mode 100644 demo/triton/text-detection/serving/model/1/README.md create mode 100644 demo/triton/text-detection/serving/model/config.pbtxt create mode 100644 demo/triton/text-ocr/grpc_client.py create mode 100644 demo/triton/text-ocr/serving/model/1/pipeline_template.json create mode 100644 demo/triton/text-ocr/serving/model/1/text_detection/README.md create mode 100644 demo/triton/text-ocr/serving/model/1/text_recognition/README.md create mode 100644 demo/triton/text-ocr/serving/model/config.pbtxt create mode 100644 demo/triton/text-recognition/README.md create mode 100644 demo/triton/text-recognition/README_CN.md create mode 100644 demo/triton/text-recognition/grpc_client.py create mode 100644 demo/triton/text-recognition/serving/model/1/README.md create mode 100644 demo/triton/text-recognition/serving/model/config.pbtxt create mode 100644 demo/triton/to_triton_model.py diff --git a/csrc/mmdeploy/codebase/mmseg/segment.cpp b/csrc/mmdeploy/codebase/mmseg/segment.cpp index 56811a4fad..6f96f99f27 100644 --- a/csrc/mmdeploy/codebase/mmseg/segment.cpp +++ b/csrc/mmdeploy/codebase/mmseg/segment.cpp @@ -90,6 +90,7 @@ class ResizeMask : public MMSegmentation { std::vector axes = {0, 3, 1, 2}; ::mmdeploy::operation::Context ctx(host, stream_); OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes)); + tensor_score.Squeeze(0); } SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_}; diff --git a/csrc/mmdeploy/triton/convert.cpp b/csrc/mmdeploy/triton/convert.cpp index 42e7de6c36..e558b7c41b 100644 --- a/csrc/mmdeploy/triton/convert.cpp +++ b/csrc/mmdeploy/triton/convert.cpp @@ -4,6 +4,7 @@ #include +#include "mmdeploy/archive/json_archive.h" #include "mmdeploy/archive/value_archive.h" #include "mmdeploy/codebase/mmaction/mmaction.h" #include "mmdeploy/codebase/mmcls/mmcls.h" @@ -67,15 +68,42 @@ void ConvertDetections(const Value& item, std::vector& tensors) { "labels"}); auto bboxes_data = bboxes.data(); auto labels_data = labels.data(); + int64_t sum_byte_size = 0; for (const auto& det : detections) { for (const auto& x : det.bbox) { *bboxes_data++ = x; } *bboxes_data++ = det.score; *labels_data++ = det.label_id; + sum_byte_size += det.mask.byte_size(); } tensors.push_back(std::move(bboxes)); tensors.push_back(std::move(labels)); + if (sum_byte_size > 0) { + // return mask + Tensor masks(TensorDesc{bboxes.device(), + ::mmdeploy::DataType::kINT8, + {static_cast(sum_byte_size)}, + "masks"}); + Tensor offs(TensorDesc{bboxes.device(), + ::mmdeploy::DataType::kINT32, + {static_cast(detections.size()), 3}, + "mask_offs"}); // [(off, w, h), ... ] + + auto masks_data = masks.data(); + auto offs_data = offs.data(); + int sum_offs = 0; + for (const auto& det : detections) { + memcpy(masks_data, det.mask.data(), det.mask.byte_size()); + masks_data += det.mask.byte_size(); + *offs_data++ = sum_offs; + *offs_data++ = det.mask.width(); + *offs_data++ = det.mask.height(); + sum_offs += det.mask.byte_size(); + } + tensors.push_back(std::move(masks)); + tensors.push_back(std::move(offs)); + } } void ConvertSegmentation(const Value& item, std::vector& tensors) { @@ -105,39 +133,75 @@ void ConvertTextDetections(const Value& item, std::vector& tensors) { ::mmdeploy::from_value(item, detections); Tensor bboxes(TensorDesc{::mmdeploy::Device(0), ::mmdeploy::DataType::kFLOAT, - {static_cast(detections.size()), 9}, - "dets"}); + {static_cast(detections.size()), 8}, + "bboxes"}); + Tensor scores(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.size()), 1}, + "scores"}); auto bboxes_data = bboxes.data(); + auto scores_data = scores.data(); for (const auto& det : detections) { bboxes_data = std::copy(det.bbox.begin(), det.bbox.end(), bboxes_data); - *bboxes_data++ = det.score; + *scores_data++ = det.score; } tensors.push_back(std::move(bboxes)); + tensors.push_back(std::move(scores)); +} + +void ConvertTextRecognitions(const Value& item, int request_count, + const std::vector& batch_per_request, + std::vector>& tensors, + std::vector& strings) { + std::vector<::mmdeploy::mmocr::TextRecognition> recognitions; + ::mmdeploy::from_value(item, recognitions); + + int k = 0; + for (int i = 0; i < request_count; i++) { + int num = batch_per_request[i]; + Tensor texts(TensorDesc{ + ::mmdeploy::Device(0), ::mmdeploy::DataType::kINT32, {static_cast(num)}, "texts"}); + Tensor score(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(num)}, + "scores"}); + auto text_data = texts.data(); + auto score_data = score.data(); + + for (int j = 0; j < num; j++) { + auto& recognition = recognitions[k++]; + text_data[j] = static_cast(strings.size()); + strings.push_back(recognition.text); + score_data[j] = static_cast(strings.size()); + strings.push_back(::mmdeploy::to_json(::mmdeploy::to_value(recognition.score)).dump()); + } + tensors[i].push_back(std::move(texts)); + tensors[i].push_back(std::move(score)); + } } void ConvertTextRecognitions(const Value& item, std::vector& tensors, std::vector& strings) { std::vector<::mmdeploy::mmocr::TextRecognition> recognitions; ::mmdeploy::from_value(item, recognitions); + Tensor texts(TensorDesc{::mmdeploy::Device(0), ::mmdeploy::DataType::kINT32, {static_cast(recognitions.size())}, - "text"}); + "rec_texts"}); Tensor score(TensorDesc{::mmdeploy::Device(0), - ::mmdeploy::DataType::kFLOAT, + ::mmdeploy::DataType::kINT32, {static_cast(recognitions.size())}, - "text_score"}); + "rec_scores"}); auto text_data = texts.data(); - auto score_data = score.data(); - for (size_t text_id = 0; text_id < recognitions.size(); ++text_id) { - text_data[text_id] = static_cast(strings.size()); - strings.push_back(recognitions[text_id].text); - auto& s = recognitions[text_id].score; - if (!s.empty()) { - score_data[text_id] = std::accumulate(s.begin(), s.end(), 0.f) / static_cast(s.size()); - } else { - score_data[text_id] = 0; - } + auto score_data = score.data(); + + for (size_t j = 0; j < recognitions.size(); j++) { + auto& recognition = recognitions[j]; + text_data[j] = static_cast(strings.size()); + strings.push_back(recognition.text); + score_data[j] = static_cast(strings.size()); + strings.push_back(::mmdeploy::to_json(::mmdeploy::to_value(recognition.score)).dump()); } tensors.push_back(std::move(texts)); tensors.push_back(std::move(score)); @@ -164,20 +228,39 @@ void ConvertPreprocess(const Value& item, std::vector& tensors, tensors.push_back(std::move(img_meta_tensor)); } -void ConvertPoseDetections(const Value& item, std::vector& tensors) { - ::mmdeploy::mmpose::PoseDetectorOutput detections; +void ConvertInference(const Value& item, std::vector& tensors) { + for (auto it = item.begin(); it != item.end(); ++it) { + auto tensor = it->get(); + auto desc = tensor.desc(); + desc.name = it.key(); + tensors.emplace_back(desc, tensor.buffer()); + } +} + +void ConvertPoseDetections(const Value& item, int request_count, + const std::vector& batch_per_request, + std::vector>& tensors) { + std::vector<::mmdeploy::mmpose::PoseDetectorOutput> detections; ::mmdeploy::from_value(item, detections); - Tensor pts(TensorDesc{::mmdeploy::Device(0), - ::mmdeploy::DataType::kFLOAT, - {static_cast(detections.key_points.size()), 3}, - "keypoints"}); - auto pts_data = pts.data(); - for (const auto& p : detections.key_points) { - *pts_data++ = p.bbox[0]; - *pts_data++ = p.bbox[1]; - *pts_data++ = p.score; + + int k = 0; + for (int i = 0; i < request_count; i++) { + int num = batch_per_request[i]; + Tensor pts(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {num, static_cast(detections[0].key_points.size()), 3}, + "keypoints"}); + auto pts_data = pts.data(); + for (int j = 0; j < num; j++) { + auto& detection = detections[k++]; + for (const auto& p : detection.key_points) { + *pts_data++ = p.bbox[0]; + *pts_data++ = p.bbox[1]; + *pts_data++ = p.score; + } + } + tensors[i].push_back(std::move(pts)); } - tensors.push_back({std::move(pts)}); } void ConvertRotatedDetections(const Value& item, std::vector& tensors) { @@ -185,7 +268,7 @@ void ConvertRotatedDetections(const Value& item, std::vector& tensors) { ::mmdeploy::from_value(item, detections); Tensor bboxes(TensorDesc{::mmdeploy::Device(0), ::mmdeploy::DataType::kFLOAT, - {static_cast(detections.detections.size()), 5}, + {static_cast(detections.detections.size()), 6}, "bboxes"}); Tensor labels(TensorDesc{::mmdeploy::Device(0), ::mmdeploy::DataType::kINT32, @@ -203,13 +286,19 @@ void ConvertRotatedDetections(const Value& item, std::vector& tensors) { } std::vector> ConvertOutputToTensors(const std::string& type, - int32_t request_count, const Value& output, + int32_t request_count, + const std::vector& batch_per_request, + const Value& output, std::vector& strings) { std::vector> tensors(request_count); if (type == "Preprocess") { for (int i = 0; i < request_count; ++i) { ConvertPreprocess(output.front()[i], tensors[i], strings); } + } else if (type == "Inference") { + for (int i = 0; i < request_count; ++i) { + ConvertInference(output.front()[i], tensors[i]); + } } else if (type == "Classifier") { for (int i = 0; i < request_count; ++i) { ConvertClassifications(output.front()[i], tensors[i]); @@ -231,13 +320,9 @@ std::vector> ConvertOutputToTensors(const std::string& type, ConvertTextDetections(output.front()[i], tensors[i]); } } else if (type == "TextRecognizer") { - for (int i = 0; i < request_count; ++i) { - ConvertTextRecognitions(output.front(), tensors[i], strings); - } + ConvertTextRecognitions(output.front(), request_count, batch_per_request, tensors, strings); } else if (type == "PoseDetector") { - for (int i = 0; i < request_count; ++i) { - ConvertPoseDetections(output.front()[i], tensors[i]); - } + ConvertPoseDetections(output.front(), request_count, batch_per_request, tensors); } else if (type == "RotatedDetector") { for (int i = 0; i < request_count; ++i) { ConvertRotatedDetections(output.front()[i], tensors[i]); diff --git a/csrc/mmdeploy/triton/convert.h b/csrc/mmdeploy/triton/convert.h index 23e593efce..0646b53d80 100644 --- a/csrc/mmdeploy/triton/convert.h +++ b/csrc/mmdeploy/triton/convert.h @@ -3,14 +3,16 @@ #ifndef MMDEPLOY_CONVERT_H #define MMDEPLOY_CONVERT_H +#include + #include "mmdeploy/core/tensor.h" #include "mmdeploy/core/value.h" namespace triton::backend::mmdeploy { std::vector> ConvertOutputToTensors( - const std::string& type, int32_t request_count, const ::mmdeploy::Value& output, - std::vector& strings); + const std::string& type, int32_t request_count, const std::vector& batch_per_request, + const ::mmdeploy::Value& output, std::vector& strings); } diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp index 4315c23cdc..65a0f1af97 100644 --- a/csrc/mmdeploy/triton/instance_state.cpp +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -3,9 +3,11 @@ #include "instance_state.h" #include +#include #include "convert.h" #include "json.hpp" +#include "json_input.h" #include "mmdeploy/archive/json_archive.h" #include "mmdeploy/core/device.h" #include "mmdeploy/core/mat.h" @@ -31,7 +33,25 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) : BackendModelInstance(model_state, triton_model_instance), model_state_(model_state), - pipeline_(model_state_->CreatePipeline(Kind(), DeviceId())) {} + pipeline_(model_state_->CreatePipeline(Kind(), DeviceId())) { + // parse parameters + ::triton::common::TritonJson::Value parameters; + model_state->ModelConfig().Find("parameters", ¶meters); + std::string info; + TryParseModelStringParameter(parameters, "merge_inputs", &info, ""); + if (info != "") { + std::stringstream ss1(info); + std::string group; + while (std::getline(ss1, group, ',')) { + std::stringstream ss2(group); + merge_inputs_.emplace_back(); + int v; + while (ss2 >> v) { + merge_inputs_.back().push_back(v); + } + } + } +} // TRITON DIR MMDeploy // (Tensor, PixFmt, Region) -> (Mat , Region) @@ -50,6 +70,42 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests ModelState* model_state = StateForModel(); + const int max_batch_size = model_state->MaxBatchSize(); + + for (size_t i = 0; i < request_count; ++i) { + if (requests[i] == nullptr) { + RequestsRespondWithError( + requests, request_count, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("null request given to MMDeploy backend for '" + Name() + "'").c_str())); + return nullptr; + } + + if (max_batch_size > 0) { + // Retrieve the batch size from one of the inputs, if the model + // supports batching, the first dimension size is batch size + // and batch dim should be 1 for mmdeploy + TRITONBACKEND_Input* input; + TRITONSERVER_Error* err = + TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input); + if (err == nullptr) { + const int64_t* shape; + err = TRITONBACKEND_InputProperties(input, nullptr, nullptr, &shape, nullptr, nullptr, + nullptr); + if (err == nullptr && shape[0] != 1) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("only support batch dim 1 for single request").c_str()); + } + } + if (err != nullptr) { + RequestsRespondWithError(requests, request_count, err); + return nullptr; + } + } + } + // 'responses' is initialized as a parallel array to 'requests', // with one TRITONBACKEND_Response object for each // TRITONBACKEND_Request object. If something goes wrong while @@ -70,35 +126,30 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests responses.push_back(response); } - BackendInputCollector collector(requests, request_count, &responses, - model_state->TritonMemoryManager(), false /* pinned_enabled */, - nullptr /* stream*/); - - // To instruct ProcessTensor to "gather" the entire batch of input - // tensors into a single contiguous buffer in CPU memory, set the - // "allowed input types" to be the CPU ones (see tritonserver.h in - // the triton-inference-server/core repo for allowed memory types). std::vector> allowed_input_types = { {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; - std::vector input_buffers(model_state->input_names().size()); - std::vector> collectors(request_count); std::vector> response_vecs(request_count); + bool need_cuda_input_sync = false; - ::mmdeploy::Value::Array image_and_metas_array; - ::mmdeploy::Value::Array input_tensors_array; - - // Setting input data for (uint32_t request_index = 0; request_index < request_count; ++request_index) { - ::mmdeploy::Value::Object input_tensors; - ::mmdeploy::Value::Object image_and_metas; response_vecs[request_index] = {responses[request_index]}; collectors[request_index] = std::make_unique( &requests[request_index], 1, &response_vecs[request_index], - model_state->TritonMemoryManager(), false, nullptr); + model_state->TritonMemoryManager(), false, CudaStream()); + } + + // Setting input data + ::mmdeploy::Value vec_inputs; + std::vector batch_per_request; + for (uint32_t request_index = 0; request_index < request_count; ++request_index) { + const auto& collector = collectors[request_index]; + ::mmdeploy::Value vec_inputi; + batch_per_request.push_back(1); for (size_t input_id = 0; input_id < model_state->input_names().size(); ++input_id) { + ::mmdeploy::Value inputi; const auto& input_name = model_state->input_names()[input_id]; // Get input shape TRITONBACKEND_Input* input{}; @@ -115,70 +166,122 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests size_t buffer_size{}; TRITONSERVER_MemoryType memory_type{}; int64_t memory_type_id{}; - RETURN_IF_ERROR(collectors[request_index]->ProcessTensor( - input_name.c_str(), nullptr, 0, allowed_input_types, &buffer, &buffer_size, - &memory_type, &memory_type_id)); - + RETURN_IF_ERROR(collector->ProcessTensor(input_name.c_str(), nullptr, 0, + allowed_input_types, &buffer, &buffer_size, + &memory_type, &memory_type_id)); ::mmdeploy::framework::Device device(0); if (memory_type == TRITONSERVER_MEMORY_GPU) { device = ::mmdeploy::framework::Device("cuda", static_cast(memory_type_id)); } - if (model_state->input_formats()[request_index] == "FORMAT_NHWC") { + + if (model_state->input_formats()[input_id] == "FORMAT_NHWC") { // Construct Mat from shape & buffer + int h, w; + if (max_batch_size > 0) { + h = dims[1]; + w = dims[2]; + } else { + h = dims[0]; + w = dims[1]; + } ::mmdeploy::framework::Mat mat( - static_cast(dims[0]), static_cast(dims[1]), ::mmdeploy::PixelFormat::kBGR, - ::mmdeploy::DataType::kINT8, + h, w, ::mmdeploy::PixelFormat::kBGR, ::mmdeploy::DataType::kINT8, std::shared_ptr(const_cast(buffer), [](auto) {}), device); - image_and_metas.insert({input_name, mat}); + inputi = {{input_name, mat}}; } else { ::mmdeploy::framework::Tensor tensor( ::mmdeploy::framework::TensorDesc{ - device, ::mmdeploy::DataType::kFLOAT, + device, ConvertDataType(model_state->input_data_types()[input_id]), ::mmdeploy::framework::TensorShape(dims, dims + dims_count), input_name}, std::shared_ptr(const_cast(buffer), [](auto) {})); - input_tensors.insert({input_name, std::move(tensor)}); + inputi = {{input_name, std::move(tensor)}}; } } else { ::mmdeploy::Value value; GetStringInputTensor(input, dims, dims_count, value); assert(value.is_array()); - ::mmdeploy::update(image_and_metas, value.front().object(), 2); + + if (value[0].contains("type")) { + const auto& type = value[0]["type"].get_ref(); + CreateJsonInput(value[0]["value"], type, inputi); + batch_per_request.back() = inputi.size(); + } else { + inputi = {{}}; + inputi.update(value.front().object()); + } } + vec_inputi.push_back(std::move(inputi)); // [ a, [b,b] ] } - if (!input_tensors.empty()) { - input_tensors_array.emplace_back(std::move(input_tensors)); - } - if (!image_and_metas.empty()) { - image_and_metas_array.emplace_back(std::move(image_and_metas)); + // broadcast, [ a, [b,b] ] -> [[a, a], [b, b]] + if (batch_per_request.back() >= 1) { + // std::vector<::mmdeploy::Value> input; + ::mmdeploy::Value input; + for (size_t i = 0; i < vec_inputi.size(); i++) { + input.push_back(::mmdeploy::Value::kArray); + } + + for (int i = 0; i < batch_per_request.back(); i++) { + for (size_t input_id = 0; input_id < model_state->input_names().size(); ++input_id) { + if (vec_inputi[input_id].is_object()) { + input[input_id].push_back(vec_inputi[input_id]); + } else { + input[input_id].push_back(vec_inputi[input_id][i]); + } + } + } + vec_inputi = input; } - // Input from device memory is not supported yet - const bool need_cuda_input_sync = collectors[request_index]->Finalize(); - if (need_cuda_input_sync) { -#if TRITON_ENABLE_GPU - cudaStreamSynchronize(CudaStream()); -#else - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, - "mmdeploy backend: unexpected CUDA sync required by collector"); -#endif + // construct [[a,a,a], [b,b,b]] + if (vec_inputs.is_null()) { + for (size_t i = 0; i < vec_inputi.size(); i++) { + vec_inputs.push_back(::mmdeploy::Value::kArray); + } + } + for (size_t i = 0; i < vec_inputi.size(); i++) { + auto&& inner = vec_inputi[i]; + for (auto&& obj : inner) { + vec_inputs[i].push_back(std::move(obj)); + } } } - ::mmdeploy::Value input_args; - if (!image_and_metas_array.empty()) { - input_args.push_back(std::move(image_and_metas_array)); - } - if (!input_tensors_array.empty()) { - input_args.push_back(std::move(input_tensors_array)); + // merget inputs for example: [[a,a,a], [b,b,b], [c,c,c]] -> [[aaa], [(b,c), (b,c), (b,c)]] + if (!merge_inputs_.empty()) { + int n_example = vec_inputs[0].size(); + ::mmdeploy::Value inputs; + for (const auto& group : merge_inputs_) { + ::mmdeploy::Value input_array; + for (int i = 0; i < n_example; i++) { + ::mmdeploy::Value input_i; + for (const auto& idx : group) { + auto&& inner = vec_inputs[idx]; + input_i.update(inner[i]); + } + input_array.push_back(std::move(input_i)); + } + inputs.push_back(std::move(input_array)); + } + vec_inputs = std::move(inputs); } - MMDEPLOY_DEBUG("input: {}", input_args); + if (need_cuda_input_sync) { +#if TRITON_ENABLE_GPU + cudaStreamSynchronize(CudaStream()); +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + "mmdeploy backend: unexpected CUDA sync required by collector"); +#endif + } uint64_t compute_start_ns = 0; SET_TIMESTAMP(compute_start_ns); - ::mmdeploy::Value outputs = pipeline_.Apply(input_args); + ::mmdeploy::Value outputs = pipeline_.Apply(vec_inputs); + // MMDEPLOY_ERROR("outputs:\n{}", outputs); + + // preprocess and inference need cuda sync { std::string device_name = "cpu"; if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { @@ -190,9 +293,8 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests } std::vector strings; - auto output_tensors = - ConvertOutputToTensors(model_state->task_type(), request_count, outputs, strings); - + auto output_tensors = ConvertOutputToTensors(model_state->task_type(), request_count, + batch_per_request, outputs, strings); uint64_t compute_end_ns = 0; SET_TIMESTAMP(compute_end_ns); @@ -201,7 +303,7 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests for (uint32_t request_index = 0; request_index < request_count; ++request_index) { responders[request_index] = std::make_unique( &requests[request_index], 1, &response_vecs[request_index], false, - model_state->TritonMemoryManager(), false, nullptr); + model_state->TritonMemoryManager(), false, CudaStream()); for (size_t output_id = 0; output_id < model_state->output_names().size(); ++output_id) { auto output_name = model_state->output_names()[output_id]; MMDEPLOY_DEBUG("output name {}", output_name); diff --git a/csrc/mmdeploy/triton/instance_state.h b/csrc/mmdeploy/triton/instance_state.h index bb0d9e71b6..7204a5237d 100644 --- a/csrc/mmdeploy/triton/instance_state.h +++ b/csrc/mmdeploy/triton/instance_state.h @@ -36,6 +36,7 @@ class ModelInstanceState : public BackendModelInstance { private: ModelState* model_state_; ::mmdeploy::Pipeline pipeline_; + std::vector> merge_inputs_; }; } // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/json_input.cpp b/csrc/mmdeploy/triton/json_input.cpp new file mode 100644 index 0000000000..9ed0803cd4 --- /dev/null +++ b/csrc/mmdeploy/triton/json_input.cpp @@ -0,0 +1,14 @@ +#include "json_input.h" + +namespace triton::backend::mmdeploy { + +void CreateJsonInput(::mmdeploy::Value &input, const std::string &type, ::mmdeploy::Value &output) { + if (type == "TextBbox") { + output = input; + } + if (type == "PoseBbox") { + output = input; + } +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/json_input.h b/csrc/mmdeploy/triton/json_input.h new file mode 100644 index 0000000000..cf6c67102d --- /dev/null +++ b/csrc/mmdeploy/triton/json_input.h @@ -0,0 +1,22 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_TRITON_JSON_INPUT_H +#define MMDEPLOY_TRITON_JSON_INPUT_H + +#include +#include + +#include "mmdeploy/archive/value_archive.h" + +namespace triton::backend::mmdeploy { + +struct TextBbox { + std::array bbox; + MMDEPLOY_ARCHIVE_MEMBERS(bbox); +}; + +void CreateJsonInput(::mmdeploy::Value &input, const std::string &type, ::mmdeploy::Value &output); + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_TRITON_JSON_INPUT_H diff --git a/csrc/mmdeploy/triton/model_state.cpp b/csrc/mmdeploy/triton/model_state.cpp index ae8ae8eb4e..26c34e496f 100644 --- a/csrc/mmdeploy/triton/model_state.cpp +++ b/csrc/mmdeploy/triton/model_state.cpp @@ -2,6 +2,11 @@ #include "model_state.h" +#include + +#include "mmdeploy/archive/json_archive.h" +#include "mmdeploy/archive/value_archive.h" +#include "mmdeploy/core/utils/filesystem.h" #include "mmdeploy/core/utils/formatter.h" #include "mmdeploy/pipeline.hpp" @@ -19,8 +24,7 @@ TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, ModelS return {}; } -ModelState::ModelState(TRITONBACKEND_Model* triton_model) - : BackendModel(triton_model), model_(JoinPath({RepositoryPath(), std::to_string(Version())})) { +ModelState::ModelState(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model) { THROW_IF_BACKEND_MODEL_ERROR(ValidateModelConfig()); } @@ -79,12 +83,39 @@ ::mmdeploy::Pipeline ModelState::CreatePipeline(TRITONSERVER_InstanceGroupKind k device_name = "cuda"; } - auto config = model_.ReadConfig("pipeline.json").value(); - - config["context"]["model"] = model_; - - ::mmdeploy::Context context(::mmdeploy::Device(device_name, device_id)); - return {config, context}; + std::string pipeline_template_path = + JoinPath({RepositoryPath(), std::to_string(Version()), "pipeline_template.json"}); + if (fs::exists(pipeline_template_path)) { + std::ifstream ifs(pipeline_template_path, std::ios::binary | std::ios::in); + ifs.seekg(0, std::ios::end); + auto size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::string str(size, '\0'); + ifs.read(str.data(), size); + + auto config = ::mmdeploy::from_json<::mmdeploy::Value>(nlohmann::json::parse(str)); + ::mmdeploy::Context context(::mmdeploy::Device(device_name, device_id)); + config["task_type"].get_to(task_type_); + config.object().erase("task_type"); + if (config.contains("model_names")) { + std::vector model_names; + ::mmdeploy::from_value(config["model_names"], model_names); + for (const auto& name : model_names) { + std::string model_path = JoinPath({RepositoryPath(), std::to_string(Version()), name}); + context.Add(name, ::mmdeploy::Model(model_path)); + } + config.object().erase("model_names"); + } + return {config, context}; + + } else { + model_ = ::mmdeploy::framework::Model(JoinPath({RepositoryPath(), std::to_string(Version())})); + auto config = model_.ReadConfig("pipeline.json").value(); + config["context"]["model"] = model_; + ::mmdeploy::Context context(::mmdeploy::Device(device_name, device_id)); + task_type_ = model_.meta().task; + return {config, context}; + } } } // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/model_state.h b/csrc/mmdeploy/triton/model_state.h index f0ef5fedbe..cbfaf00aa1 100644 --- a/csrc/mmdeploy/triton/model_state.h +++ b/csrc/mmdeploy/triton/model_state.h @@ -24,7 +24,7 @@ class ModelState : public BackendModel { ::mmdeploy::Pipeline CreatePipeline(TRITONSERVER_InstanceGroupKind kind, int device_id); - const std::string& task_type() { return model_.meta().task; } + const std::string& task_type() { return task_type_; } private: explicit ModelState(TRITONBACKEND_Model* triton_model); @@ -33,6 +33,7 @@ class ModelState : public BackendModel { private: ::mmdeploy::framework::Model model_; + std::string task_type_; std::vector input_names_; std::vector output_names_; std::vector input_data_types_; diff --git a/demo/python/to_triton_model.py b/demo/python/to_triton_model.py deleted file mode 100644 index 92d768870f..0000000000 --- a/demo/python/to_triton_model.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import argparse -import json -import os.path as osp - -import tritonclient.grpc.model_config_pb2 as pb -from google.protobuf import text_format - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('model_path', type=str) - parser.add_argument('--name', type=str) - parser.add_argument('--nocopy', type=bool) - return parser.parse_args() - - -IMAGE_INPUT = [ - dict( - name='ori_img', - dtype=pb.TYPE_UINT8, - format=pb.ModelInput.FORMAT_NHWC, - dims=[-1, -1, 3]), - dict(name='pix_fmt', dtype=pb.TYPE_INT32, dims=[1], optional=True) -] - -TASK_OUTPUT = dict( - Preprocess=[ - dict(name='img', dtype=pb.TYPE_FP32, dims=[3, -1, -1]), - dict(name='img_metas', dtype=pb.TYPE_STRING, dims=[1]) - ], - Classifier=[ - dict(name='scores', dtype=pb.TYPE_FP32, dims=[-1, 1]), - dict(name='label_ids', dtype=pb.TYPE_FP32, dims=[-1, 1]) - ], - Detector=[dict(name='dets', dtype=pb.TYPE_FP32, dims=[-1, 1])], - Segmentor=[ - dict(name='mask', dtype=pb.TYPE_INT32, dims=[-1, -1]), - dict(name='score', dtype=pb.TYPE_FP32, dims=[-1, -1, -1]) - ], - Restorer=[dict(name='output', dtype=pb.TYPE_FP32, dims=[-1, -1, 3])], - TextDetector=[], - TextRecognizer=[], - PoseDetector=[], - RotatedDetector=[], - TextOCR=[], - DetPose=[]) - - -def add_input(model_config, params): - p = model_config.input.add() - p.name = params['name'] - p.data_type = params['dtype'] - p.dims.extend(params['dims']) - if 'format' in params: - p.format = params['format'] - if 'optional' in params: - p.optional = params['optional'] - - -def add_output(model_config, params): - p = model_config.output.add() - p.name = params['name'] - p.data_type = params['dtype'] - p.dims.extend(params['dims']) - - -def serialize_model_config(model_config): - return text_format.MessageToString( - model_config, - use_short_repeated_primitives=True, - use_index_order=True, - print_unknown_fields=True) - - -def create_model_config(name, task, backend=None, platform=None): - model_config = pb.ModelConfig() - if backend: - model_config.backend = backend - if platform: - model_config.platform = platform - model_config.name = name - model_config.max_batch_size = 0 - - for input in IMAGE_INPUT: - add_input(model_config, input) - for output in TASK_OUTPUT[task]: - add_output(model_config, output) - return model_config - - -def create_preprocess_model(): - pass - - -def get_onnx_io_names(detail_info): - onnx_config = detail_info['onnx_config'] - return onnx_config['input_names'], onnx_config['output_names'] - - -def create_inference_model(deploy_info, pipeline_info, detail_info): - if 'pipeline' in pipeline_info: - # old-style pipeline specification - pipeline = pipeline_info['pipeline']['tasks'] - else: - pipeline = pipeline_info['tasks'] - - for task_cfg in pipeline: - if task_cfg['module'] == 'Net': - input_names, output_names = get_onnx_io_names(detail_info) - - -def create_postprocess_model(): - pass - - -def create_pipeline_model(): - pass - - -def create_ensemble_model(deploy_cfg, pipeline_cfg): - inference_model_config = create_inference_model(deploy_cfg, pipeline_cfg) - preprocess_model_config = create_preprocess_model() - postprocess_model_config = create_postprocess_model() - pipeline_model_config = create_pipeline_model() - - -def main(): - args = parse_args() - model_path = args.model_path - if not osp.isdir(model_path): - model_path = osp.split(model_path)[-2] - if osp.isdir(model_path): - with open(osp.join(model_path, 'deploy.json'), 'r') as f: - deploy_cfg = json.load(f) - with open(osp.join(model_path, 'pipeline.json'), 'r') as f: - pipeline_cfg = json.load(f) - task = deploy_cfg['task'] - model_config = create_model_config('model', task, 'onnxruntime') - data = serialize_model_config(model_config) - print(data) - - -if __name__ == '__main__': - main() diff --git a/demo/python/triton_client.py b/demo/python/triton_client.py deleted file mode 100644 index 98f3e38d82..0000000000 --- a/demo/python/triton_client.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import argparse -from math import cos, sin - -import cv2 -import numpy as np -import tritonclient.grpc as grpcclient - - -def parse_args(): - parser = argparse.ArgumentParser(description='') - parser.add_argument('model_name', type=str) - parser.add_argument('image_path', type=str) - parser.add_argument( - '-v', '--model_version', type=str, required=False, default='1') - parser.add_argument( - '-u', '--url', type=str, required=False, default='localhost:8001') - return parser.parse_args() - - -def get_palette(num_classes=256): - state = np.random.get_state() - # random color - np.random.seed(42) - palette = np.random.randint(0, 256, size=(num_classes, 3)) - np.random.set_state(state) - return [tuple(c) for c in palette] - - -def vis_cls(img, scores, label_ids): - print('\n'.join(map(str, zip(scores, label_ids)))) - - -def vis_det(img, bboxes, labels): - for bbox, label in zip(bboxes, labels): - (left, top, right, bottom), score = bbox[0:4].astype(int), bbox[4] - if score < 0.3: - continue - cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0)) - return img - - -def vis_rdet(img, bboxes, labels): - for rbbox, label_id in zip(bboxes, labels): - [cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1] - if score < 0.1: - continue - [wx, wy, hx, hy] = \ - 0.5 * np.array([w, w, -h, h]) * \ - np.array([cos(angle), sin(angle), sin(angle), cos(angle)]) - points = np.array([[[int(cx - wx - hx), - int(cy - wy - hy)], - [int(cx + wx - hx), - int(cy + wy - hy)], - [int(cx + wx + hx), - int(cy + wy + hy)], - [int(cx - wx + hx), - int(cy - wy + hy)]]]) - cv2.drawContours(img, points, -1, (0, 255, 0), 2) - return img - - -def vis_seg(img, mask, scores): - if mask is not None: - print(f'mask {mask.shape}') - if scores is not None: - print(f'scores {scores.shape}') - if mask is None: - mask = np.argmax(scores, axis=0) - - palette = get_palette() - color_seg = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) - for label, color in enumerate(palette): - color_seg[mask == label, :] = color - # convert to BGR - color_seg = color_seg[..., ::-1] - - img = img * 0.5 + color_seg * 0.5 - return img.astype(np.uint8) - - -def vis_ocr(img, dets, text, text_score): - pts = ((dets[:, 0:8] + 0.5).reshape(len(dets), -1, 2).astype(int)) - cv2.polylines(img, pts, True, (0, 255, 0), 2) - print('\n'.join(map(str, zip(range(len(text)), text, text_score)))) - return img - - -def vis_pose(img, dets, kpts): - pass - - -def main(): - args = parse_args() - triton_client = grpcclient.InferenceServerClient(url=args.url) - - model_config = triton_client.get_model_config( - model_name=args.model_name, model_version=args.model_version) - - img = cv2.imread(args.image_path) - - if img is None: - print(f'failed to load image {args.image_path}') - return - - task = model_config.config.parameters['task'].string_value - - task_map = dict( - Classifier=(('scores', 'labels'), vis_cls), - Detector=(('bboxes', 'labels'), vis_det), - TextOCR=(('dets', 'text', 'text_score'), vis_ocr), - Restorer=(('output', ), lambda _, hires: hires), - Segmentor=(('mask', 'score'), vis_seg), - RotatedDetector=(('bboxes', 'labels'), None), - DetPose=(('bboxes', 'keypoints'), vis_pose)) - - output_names, visualize = task_map[task] - - # request input - inputs = [grpcclient.InferInput('ori_img', img.shape, 'UINT8')] - inputs[0].set_data_from_numpy(img) - - # request outputs - outputs = map(grpcclient.InferRequestedOutput, output_names) - - # run inference - response = triton_client.infer( - model_config.config.name, inputs, outputs=list(outputs)) - - # visualize results - vis = visualize(img, *map(response.as_numpy, output_names)) - - if vis is not None: - cv2.imwrite('vis.jpg', vis) - # cv2.imshow('', vis) - # cv2.waitKey(0) - - -if __name__ == '__main__': - main() diff --git a/demo/triton/image-classification/README.md b/demo/triton/image-classification/README.md new file mode 100644 index 0000000000..f881368b4a --- /dev/null +++ b/demo/triton/image-classification/README.md @@ -0,0 +1 @@ +python tools/deploy.py configs/mmpretrain/classification_tensorrt_static-224x224.py ../mmpretrain/configs/resnet/resnet18_8xb32_in1k.py ../checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth ../mmclassification/demo/demo.JPEG --device cuda --work-dir work_dirs/resnet --dump-info \ No newline at end of file diff --git a/demo/triton/image-classification/README_CN.md b/demo/triton/image-classification/README_CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/image-classification/grpc_client.py b/demo/triton/image-classification/grpc_client.py new file mode 100644 index 0000000000..51a876fb4b --- /dev/null +++ b/demo/triton/image-classification/grpc_client.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(results): + labels = results['labels'] + scores = results['scores'] + assert len(labels) == len(scores) + topk = len(labels) + print(f'top {topk} results:') + for i in range(topk): + print(f'label {labels[i]} score {scores[i]}') + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(results) diff --git a/demo/triton/image-classification/serving/model/1/README.md b/demo/triton/image-classification/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/image-classification/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/image-classification/serving/model/config.pbtxt b/demo/triton/image-classification/serving/model/config.pbtxt new file mode 100644 index 0000000000..c141614e46 --- /dev/null +++ b/demo/triton/image-classification/serving/model/config.pbtxt @@ -0,0 +1,20 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "scores" + data_type: TYPE_FP32 + dims: [ -1 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} diff --git a/demo/triton/instance-segmentation/README.md b/demo/triton/instance-segmentation/README.md new file mode 100644 index 0000000000..db199cbfe4 --- /dev/null +++ b/demo/triton/instance-segmentation/README.md @@ -0,0 +1 @@ +python tools/deploy.py configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py ../mmdetection/configs/mask_rcnn/mask-rcnn_r50_fpn_2x_coco.py https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_20200505_003907-3e542a40.pth ../mmdetection/demo/demo.jpg --work-dir work_dir/maskrcnn --dump-info --device cuda \ No newline at end of file diff --git a/demo/triton/instance-segmentation/README_zh-CN.md b/demo/triton/instance-segmentation/README_zh-CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/instance-segmentation/grpc_client.py b/demo/triton/instance-segmentation/grpc_client.py new file mode 100644 index 0000000000..23b33d7d43 --- /dev/null +++ b/demo/triton/instance-segmentation/grpc_client.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +import math + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + labels = results['labels'] + masks = results['masks'] + mask_offs = results['mask_offs'] + assert len(bboxes) == len(labels) + for i in range(len(bboxes)): + x1, y1, x2, y2, score = bboxes[i] + if score < 0.5: + continue + x1, y1, x2, y2 = map(int, (x1, y1, x2, y2)) + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1) + + off, w, h = mask_offs[i] + mask_data = masks[off:off + w * h] + mask = mask_data.reshape(h, w) + + blue, green, red = cv2.split(img) + x0 = int(max(math.floor(x1) - 1, 0)) + y0 = int(max(math.floor(y1) - 1, 0)) + mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]] + cv2.bitwise_or(mask, mask_img, mask_img) + img = cv2.merge([blue, green, red]) + + cv2.imwrite('instance-segmentation.jpg', img) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/instance-segmentation/serving/model/1/README.md b/demo/triton/instance-segmentation/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/instance-segmentation/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/instance-segmentation/serving/model/config.pbtxt b/demo/triton/instance-segmentation/serving/model/config.pbtxt new file mode 100644 index 0000000000..4ec61a589d --- /dev/null +++ b/demo/triton/instance-segmentation/serving/model/config.pbtxt @@ -0,0 +1,30 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 5 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} +output { + name: "masks" + data_type: TYPE_UINT8 + dims: [ -1 ] +} +output { + name: "mask_offs" + data_type: TYPE_INT32 + dims: [ -1, 3 ] +} \ No newline at end of file diff --git a/demo/triton/keypoint-detection/README.md b/demo/triton/keypoint-detection/README.md new file mode 100644 index 0000000000..2f2b9e881f --- /dev/null +++ b/demo/triton/keypoint-detection/README.md @@ -0,0 +1,4 @@ +python tools/deploy.py configs/mmpose/pose-detection_tensorrt_static-256x192.py ../mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py ../checkpoints/td-hm_hrnet-w32_8xb64-210e_coco-256x192-81c58e40_20220909.pth demo/resources/human-pose.jpg --work-dir work_dir/hrnet --dump-info --device cuda + +python tools/deploy.py configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py ../mmpose/configs/body_2d_keypoint/simcc/coco/simcc_res50_8xb64-210e_coco-256x192.py https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/simcc/coco/simcc_res50_8xb64-210e_coco-256x192-8e0f5b59_20220919.pth demo/resources/human-pose.jpg --work-dir work_dir/pose2 --dump-info --device cuda + diff --git a/demo/triton/keypoint-detection/README_CN.md b/demo/triton/keypoint-detection/README_CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/keypoint-detection/grpc_client.py b/demo/triton/keypoint-detection/grpc_client.py new file mode 100644 index 0000000000..9a3ee02590 --- /dev/null +++ b/demo/triton/keypoint-detection/grpc_client.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +import numpy as np +import json + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image, box): + """ + Args: + image: np.ndarray + box: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8"), + InferInput(self._input_names[1], box.shape, + "BYTES")] + inputs[0].set_data_from_numpy(image) + inputs[1].set_data_from_numpy(box) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + batch_keypoints = results['keypoints'] + for keypoints in batch_keypoints: + n = keypoints.shape[0] + for i in range(n): + x, y, score = keypoints[i] + x, y = map(int, (x, y)) + cv2.circle(img, (x, y), 1, (0, 255, 0), 2) + cv2.imwrite('keypoint-detection.jpg', img) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + bbox = { + 'type': 'PoseBbox', + 'value': [ + { + 'bbox': [0.0, 0.0, img.shape[1], img.shape[0]] + } + ] + } + bbox = np.array([json.dumps(bbox).encode('utf-8')]) + results = client.infer(img, bbox) + visualize(img, results) diff --git a/demo/triton/keypoint-detection/serving/model/1/README.md b/demo/triton/keypoint-detection/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/keypoint-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/keypoint-detection/serving/model/config.pbtxt b/demo/triton/keypoint-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..fa6380c421 --- /dev/null +++ b/demo/triton/keypoint-detection/serving/model/config.pbtxt @@ -0,0 +1,29 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +input { + name: "PoseBbox" + data_type: TYPE_STRING + dims: [ 1 ] + allow_ragged_batch: true +} + +output { + name: "keypoints" + data_type: TYPE_INT32 + dims: [ -1, -1, 3 ] +} + +parameters { + key: "merge_inputs", + value: { + string_value: "0 1" + } +} \ No newline at end of file diff --git a/demo/triton/object-detection/README.md b/demo/triton/object-detection/README.md new file mode 100644 index 0000000000..1572845fde --- /dev/null +++ b/demo/triton/object-detection/README.md @@ -0,0 +1 @@ +python tools/deploy.py configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py ../mmdetection/configs/retinanet/retinanet_r18_fpn_1x_coco.py https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth ../mmdetection/demo/demo.jpg --work-dir work_dir/retinanet --dump-info --device cuda \ No newline at end of file diff --git a/demo/triton/object-detection/README_zh-CN.md b/demo/triton/object-detection/README_zh-CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/object-detection/grpc_client.py b/demo/triton/object-detection/grpc_client.py new file mode 100644 index 0000000000..3795a84e90 --- /dev/null +++ b/demo/triton/object-detection/grpc_client.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + labels = results['labels'] + assert len(bboxes) == len(labels) + for i in range(len(bboxes)): + x1, y1, x2, y2, score = bboxes[i] + if score < 0.5: + continue + x1, y1, x2, y2 = map(int, (x1, y1, x2, y2)) + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1) + cv2.imwrite('object-detection.jpg', img) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/object-detection/serving/model/1/README.md b/demo/triton/object-detection/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/object-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/object-detection/serving/model/config.pbtxt b/demo/triton/object-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..04913e0181 --- /dev/null +++ b/demo/triton/object-detection/serving/model/config.pbtxt @@ -0,0 +1,20 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 5 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} diff --git a/demo/triton/oriented-object-detection/README.md b/demo/triton/oriented-object-detection/README.md new file mode 100644 index 0000000000..5a055b2be9 --- /dev/null +++ b/demo/triton/oriented-object-detection/README.md @@ -0,0 +1 @@ +python tools/deploy.py configs/mmrotate/rotated-detection_tensorrt_dynamic-320x320-1024x1024.py ../mmrotate/configs/rotated_faster_rcnn/rotated-faster-rcnn-le90_r50_fpn_1x_dota.py ../../mmrotate/checkpoint/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth ../mmrotate/demo/demo.jpg --dump-info --work-dir work_dir/rrcnn --device cuda \ No newline at end of file diff --git a/demo/triton/oriented-object-detection/README_zh-CN.md b/demo/triton/oriented-object-detection/README_zh-CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/oriented-object-detection/grpc_client.py b/demo/triton/oriented-object-detection/grpc_client.py new file mode 100644 index 0000000000..299944f944 --- /dev/null +++ b/demo/triton/oriented-object-detection/grpc_client.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +import numpy as np +from math import cos, sin + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + labels = results['labels'] + for rbbox, label_id in zip(bboxes, labels): + [cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1] + if score < 0.1: + continue + [wx, wy, hx, hy] = \ + 0.5 * np.array([w, w, -h, h]) * \ + np.array([cos(angle), sin(angle), sin(angle), cos(angle)]) + points = np.array([[[int(cx - wx - hx), + int(cy - wy - hy)], + [int(cx + wx - hx), + int(cy + wy - hy)], + [int(cx + wx + hx), + int(cy + wy + hy)], + [int(cx - wx + hx), + int(cy - wy + hy)]]]) + cv2.drawContours(img, points, -1, (0, 255, 0), 2) + cv2.imwrite('oriented-object-detection.jpg', img) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/oriented-object-detection/serving/model/1/README.md b/demo/triton/oriented-object-detection/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/oriented-object-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/oriented-object-detection/serving/model/config.pbtxt b/demo/triton/oriented-object-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..142570726f --- /dev/null +++ b/demo/triton/oriented-object-detection/serving/model/config.pbtxt @@ -0,0 +1,20 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 6 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} diff --git a/demo/triton/semantic-segmentation/README.md b/demo/triton/semantic-segmentation/README.md new file mode 100644 index 0000000000..12c3275d29 --- /dev/null +++ b/demo/triton/semantic-segmentation/README.md @@ -0,0 +1 @@ +python tools/deploy.py configs/mmseg/segmentation_tensorrt-fp16_static-512x1024.py ../mmsegmentation/configs/pspnet/pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py ../../checkpoints/pspnet_r18-d8_512x1024_80k_cityscapes_20201225_021458-09ffa746.pth ../mmsegmentation/demo/demo.png --work-dir work_dir/pspnet --dump-info --device cuda \ No newline at end of file diff --git a/demo/triton/semantic-segmentation/README_CN.md b/demo/triton/semantic-segmentation/README_CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/semantic-segmentation/grpc_client.py b/demo/triton/semantic-segmentation/grpc_client.py new file mode 100644 index 0000000000..72e967a230 --- /dev/null +++ b/demo/triton/semantic-segmentation/grpc_client.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +def get_palette(num_classes=256): + state = np.random.get_state() + # random color + np.random.seed(42) + palette = np.random.randint(0, 256, size=(num_classes, 3)) + np.random.set_state(state) + return [tuple(c) for c in palette] + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + if 'mask' in results: + seg = results['mask'] + else: + score = results['score'] + seg = np.argmax(score, axis=0) + + palette = get_palette() + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * 0.5 + color_seg * 0.5 + img = img.astype(np.uint8) + cv2.imwrite('semantic-segmentation.png', img) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/semantic-segmentation/serving/mask/model/1/README.md b/demo/triton/semantic-segmentation/serving/mask/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/mask/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt b/demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt new file mode 100644 index 0000000000..e7a71e570d --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt @@ -0,0 +1,15 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [-1, -1, 3] + allow_ragged_batch: true +} + +output { + name: "mask" + data_type: TYPE_INT32 + dims: [ -1, -1 ] +} diff --git a/demo/triton/semantic-segmentation/serving/score/model/1/README.md b/demo/triton/semantic-segmentation/serving/score/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/score/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt b/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt new file mode 100644 index 0000000000..e905be48e4 --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt @@ -0,0 +1,15 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [-1, -1, 3] + allow_ragged_batch: true +} + +output { + name: "score" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] +} \ No newline at end of file diff --git a/demo/triton/text-detection/README.md b/demo/triton/text-detection/README.md new file mode 100644 index 0000000000..7ecb52c498 --- /dev/null +++ b/demo/triton/text-detection/README.md @@ -0,0 +1,6 @@ +python tools/deploy.py configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py ../mmocr/configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py https://download.openmmlab.com/mmocr/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth ../mmocr/demo/demo_text_det.jpg --work-dir work_dir/panet --dump-info --device cuda:0 + +python tools/deploy.py configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py ../mmocr/configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015/dbnet_resnet18_fpnc_1200e_icdar2015_20220825_221614-7c0e94f2.pth ../mmocr/demo/demo_text_det.jpg --work-dir work_dir/dbnet --dump-info --device cuda:0 + + +python tools/deploy.py configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py ../mmocr/configs/textdet/psenet/psenet_resnet50_fpnf_600e_icdar2015.py https://download.openmmlab.com/mmocr/textdet/psenet/psenet_resnet50_fpnf_600e_icdar2015/psenet_resnet50_fpnf_600e_icdar2015_20220825_222709-b6741ec3.pth ../mmocr/demo/demo_text_det.jpg --work-dir work_dir/psenet --dump-info --device cuda:0 \ No newline at end of file diff --git a/demo/triton/text-detection/README_CN.md b/demo/triton/text-detection/README_CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/text-detection/grpc_client.py b/demo/triton/text-detection/grpc_client.py new file mode 100644 index 0000000000..64abd1e128 --- /dev/null +++ b/demo/triton/text-detection/grpc_client.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +import numpy as np +import json + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + scores = results['scores'] + for (bbox, score) in zip(bboxes, scores): + x = list(map(int, bbox[::2])) + y = list(map(int, bbox[1::2])) + n = len(x) + for i in range(n): + p1 = (x[i], y[i]) + p2 = (x[(i + 1) % n], y[(i + 1) % n]) + img = cv2.line(img, p1, p2, (0, 255, 0), 1) + cv2.imwrite('text-detection.jpg', img) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/text-detection/serving/model/1/README.md b/demo/triton/text-detection/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/text-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/text-detection/serving/model/config.pbtxt b/demo/triton/text-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..b8afc6aa76 --- /dev/null +++ b/demo/triton/text-detection/serving/model/config.pbtxt @@ -0,0 +1,21 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [-1, -1, 3] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 8 ] +} + +output { + name: "scores" + data_type: TYPE_FP32 + dims: [ -1, 1 ] +} diff --git a/demo/triton/text-ocr/grpc_client.py b/demo/triton/text-ocr/grpc_client.py new file mode 100644 index 0000000000..28729f10fa --- /dev/null +++ b/demo/triton/text-ocr/grpc_client.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8")] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(results): + det_bboxes = results['bboxes'] + det_scores = results['scores'] + rec_texts = results['rec_texts'] + rec_scores = results['rec_scores'] + for i, (det_bbox, det_score, rec_text, rec_score) in \ + enumerate(zip(det_bboxes, det_scores, rec_texts, rec_scores)): + print(f'bbox[{i}] ({det_bbox[0]:.2f}, {det_bbox[1]:.2f}), ' + f'({det_bbox[2]:.2f}, {det_bbox[3]:.2f}), ({det_bbox[4]:.2f}, {det_bbox[5]:.2f}), ' + f'({det_bbox[6]:.2f}, {det_bbox[7]:.2f}), {det_score[0]:.2f}') + text = rec_text.decode('utf-8') + print(f'text[{i}] {text}') + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(results) diff --git a/demo/triton/text-ocr/serving/model/1/pipeline_template.json b/demo/triton/text-ocr/serving/model/1/pipeline_template.json new file mode 100644 index 0000000000..f841e83c9f --- /dev/null +++ b/demo/triton/text-ocr/serving/model/1/pipeline_template.json @@ -0,0 +1,50 @@ +{ + "model_names": [ + "text_detection", + "text_recognition" + ], + "task_type": "TextOCR", + "type": "Pipeline", + "input": "img", + "output": [ + "dets", + "texts" + ], + "tasks": [ + { + "type": "Inference", + "input": "img", + "output": "dets", + "params": { + "model": "text_detection" + } + }, + { + "type": "Pipeline", + "input": [ + "bboxes=*dets", + "imgs=+img" + ], + "tasks": [ + { + "type": "Task", + "module": "WarpBbox", + "input": [ + "imgs", + "bboxes" + ], + "output": "patches" + }, + { + "type": "Inference", + "input": "patches", + "output": "texts", + "params": { + "model": "text_recognition" + } + } + ], + "output": "*texts" + } + ] +} \ No newline at end of file diff --git a/demo/triton/text-ocr/serving/model/1/text_detection/README.md b/demo/triton/text-ocr/serving/model/1/text_detection/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/text-ocr/serving/model/1/text_recognition/README.md b/demo/triton/text-ocr/serving/model/1/text_recognition/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/text-ocr/serving/model/config.pbtxt b/demo/triton/text-ocr/serving/model/config.pbtxt new file mode 100644 index 0000000000..9864802ad8 --- /dev/null +++ b/demo/triton/text-ocr/serving/model/config.pbtxt @@ -0,0 +1,33 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 8 ] +} + +output { + name: "scores" + data_type: TYPE_FP32 + dims: [ -1, 1 ] +} + +output { + name: "rec_texts" + data_type: TYPE_STRING + dims: [ -1, 1] +} + +output { + name: "rec_scores" + data_type: TYPE_STRING + dims: [ -1, 1 ] +} diff --git a/demo/triton/text-recognition/README.md b/demo/triton/text-recognition/README.md new file mode 100644 index 0000000000..9eaa1df85a --- /dev/null +++ b/demo/triton/text-recognition/README.md @@ -0,0 +1 @@ +python tools/deploy.py configs/mmocr/text-recognition/text-recognition_tensorrt-fp16_dynamic-1x32x32-1x32x640.py ../mmocr/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth ../mmocr/demo/demo_text_recog.jpg --work-dir work_dir/crnn --device cuda --dump-info \ No newline at end of file diff --git a/demo/triton/text-recognition/README_CN.md b/demo/triton/text-recognition/README_CN.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/text-recognition/grpc_client.py b/demo/triton/text-recognition/grpc_client.py new file mode 100644 index 0000000000..8d045e2aae --- /dev/null +++ b/demo/triton/text-recognition/grpc_client.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import cv2 +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +import numpy as np +import json + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, + help='model name') + parser.add_argument('image', type=str, + help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output for output in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image, box): + """ + Args: + image: np.ndarray + box: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, + "UINT8"), + InferInput(self._input_names[1], box.shape, + "BYTES")] + inputs[0].set_data_from_numpy(image) + inputs[1].set_data_from_numpy(box) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(results): + texts = results['texts'] + scores = results['scores'] + for box_texts_, box_scores_ in zip(texts, scores): + box_texts = box_texts_.decode('utf-8') + box_scores = json.loads(box_scores_.decode('utf-8')) + print(box_texts, box_scores) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name + model_version = "1" + url = "localhost:8001" + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + bbox = { + 'type': 'TextBbox', + 'value': [ + { + 'bbox': [0.0, 0.0, img.shape[1], 0, img.shape[1], img.shape[0], 0, img.shape[0]], + } + ] + } + bbox = np.array([json.dumps(bbox).encode('utf-8')]) + results = client.infer(img, bbox) + visualize(results) diff --git a/demo/triton/text-recognition/serving/model/1/README.md b/demo/triton/text-recognition/serving/model/1/README.md new file mode 100644 index 0000000000..ff6d1ae274 --- /dev/null +++ b/demo/triton/text-recognition/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. \ No newline at end of file diff --git a/demo/triton/text-recognition/serving/model/config.pbtxt b/demo/triton/text-recognition/serving/model/config.pbtxt new file mode 100644 index 0000000000..759911a09b --- /dev/null +++ b/demo/triton/text-recognition/serving/model/config.pbtxt @@ -0,0 +1,28 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +input { + name: "TextBbox" + data_type: TYPE_STRING + dims: [ 1 ] + allow_ragged_batch: true +} + +output { + name: "texts" + data_type: TYPE_STRING + dims: [ -1, 1] +} + +output { + name: "scores" + data_type: TYPE_STRING + dims: [ -1, 1 ] +} diff --git a/demo/triton/to_triton_model.py b/demo/triton/to_triton_model.py new file mode 100644 index 0000000000..cdeb2d9ff6 --- /dev/null +++ b/demo/triton/to_triton_model.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json +import os +import os.path as osp +from enum import Enum, unique +import shutil +from glob import glob + +BASEDIR = os.path.dirname(__file__) + + +@unique +class Template(str, Enum): + ImageClassification = 'image-classification/serving' + InstanceSegmentation = 'instance-segmentation/serving' + KeypointDetection = 'keypoint-detection/serving' + ObjectDetection = 'object-detection/serving' + OrientedObjectDetection = 'oriented-object-detection/serving' + SemanticSegmentation1 = 'semantic-segmentation/serving/mask' + SemanticSegmentation2 = 'semantic-segmentation/serving/score' + TextRecognition = 'text-recognition/serving' + TextDetection = 'text-detection/serving' + + +def copy_template(src_folder, dst_folder): + files = glob(osp.join(src_folder, '*')) + for src in files: + dst = osp.join(dst_folder, osp.basename(src)) + if osp.isdir(src): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.copy(src, dst) + + +class Convert: + + def __init__(self, model_type, model_dir, deploy_cfg, pipeline_cfg, detail_cfg, output_dir): + self._model_type = model_type + self._model_dir = model_dir + self._deploy_cfg = deploy_cfg + self._pipeline_cfg = pipeline_cfg + self._detail_cfg = detail_cfg + self._output_dir = output_dir + + def copy_file(self, file_name, src_folder, dst_folder): + src_path = osp.join(src_folder, file_name) + dst_path = osp.join(dst_folder, file_name) + if osp.isdir(src_path): + shutil.copytree(src_path, dst_path) + else: + shutil.copy(src_path, dst_path) + + def write_json_file(self, data, file_name, dst_folder): + dst_path = osp.join(dst_folder, file_name) + with open(dst_path, 'w') as f: + json.dump(data, f, indent=4) + + def create_single_model(self): + output_model_folder = osp.join(self._output_dir, 'model', '1') + if (self._model_type == Template.TextRecognition): + self._pipeline_cfg['pipeline']['input'].append('bbox') + self._pipeline_cfg['pipeline']['tasks'][0]['input'] = ['patch'] + warpbbox = { + "type": "Task", + "module": "WarpBbox", + "input": [ + "img", + "bbox" + ], + "output": [ + "patch" + ] + } + self._pipeline_cfg['pipeline']['tasks'].insert(0, warpbbox) + self.write_json_file(self._pipeline_cfg, + 'pipeline.json', output_model_folder) + else: + self.copy_file('pipeline.json', self._model_dir, + output_model_folder) + + self.copy_file('deploy.json', self._model_dir, output_model_folder) + models = self._deploy_cfg['models'] + for model in models: + net = model['net'] + self.copy_file(net, self._model_dir, output_model_folder) + for custom in self._deploy_cfg['customs']: + self.copy_file(custom, self._model_dir, output_model_folder) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_dir', type=str, + help='converted model dir with `--dump-info` flag when convert the model') + parser.add_argument('output_dir', type=str, + help='output dir') + return parser.parse_args() + + +def get_model_type(detail_cfg, pipeline_cfg): + task = detail_cfg['codebase_config']['task'] + output_names = detail_cfg['onnx_config']['output_names'] + + if task == 'Classification': + return Template.ImageClassification + if task == 'ObjectDetection': + if 'masks' in output_names: + return Template.InstanceSegmentation + else: + return Template.ObjectDetection + if task == 'Segmentation': + with_argmax = pipeline_cfg['pipeline']['tasks'][-1]['params'].get( + 'with_argmax', True) + if with_argmax: + return Template.SemanticSegmentation1 + else: + return Template.SemanticSegmentation2 + if task == 'PoseDetection': + return Template.KeypointDetection + if task == 'RotatedDetection': + return Template.OrientedObjectDetection + if task == 'TextRecognition': + return Template.TextRecognition + if task == 'TextDetection': + return Template.TextDetection + + assert 0, f'doesn\'t support task {task} with output_names: {output_names}' + + +if __name__ == '__main__': + args = parse_args() + model_dir = args.model_dir + output_dir = args.output_dir + + # check + assert osp.isdir(model_dir), f'model dir {model_dir} doesn\'t exist' + info_files = ['deploy.json', 'pipeline.json', 'detail.json'] + for file in info_files: + path = osp.join(model_dir, file) + assert osp.exists(path), f'{path} doesn\'t exist in {model_dir}' + + with open(osp.join(model_dir, 'deploy.json')) as f: + deploy_cfg = json.load(f) + with open(osp.join(model_dir, 'pipeline.json')) as f: + pipeline_cfg = json.load(f) + with open(osp.join(model_dir, 'detail.json')) as f: + detail_cfg = json.load(f) + assert 'onnx_config' in detail_cfg, f'currently, only support onnx as middle ir' + + # process + model_type = get_model_type(detail_cfg, pipeline_cfg) + convert = Convert(model_type, model_dir, deploy_cfg, pipeline_cfg, + detail_cfg, output_dir) + + src_folder = osp.join(BASEDIR, model_type.value) + + if not osp.exists(output_dir): + os.makedirs(output_dir) + copy_template(src_folder, output_dir) + convert.create_single_model() + + +# /data/testmodel/mmcls + +# /home/cx/ws/2.0/mmdeploy/work_dir/rtn2/ +# /home/cx/ws/2.0/mmdeploy/work_dir/maskrcnn/ + +# /home/cx/ws/2.0/mmdeploy/work_dir/pose1 +# /home/cx/ws/2.0/mmdeploy/work_dir/pose2 + +# /home/cx/ws/2.0/mmdeploy/work_dir/pspnet +# /home/cx/ws/2.0/mmdeploy/work_dir/pspnet_mask + +# /home/cx/ws/2.0/mmdeploy/work_dirs/rrcnn + +# /home/cx/ws/2.0/mmdeploy/work_dir/panet + +# /home/cx/ws/2.0/mmdeploy/work_dir/crnn From bfa3c64e5fb2bd16c3e79b1c35e56d2ca8118b95 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 15 May 2023 17:48:05 +0800 Subject: [PATCH 10/16] update triton cmakelist --- csrc/mmdeploy/triton/CMakeLists.txt | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/mmdeploy/triton/CMakeLists.txt b/csrc/mmdeploy/triton/CMakeLists.txt index efdf880d7a..c77b8610d9 100644 --- a/csrc/mmdeploy/triton/CMakeLists.txt +++ b/csrc/mmdeploy/triton/CMakeLists.txt @@ -38,9 +38,13 @@ project(tritonmmdeploybackend LANGUAGES C CXX) # doesn't use GPUs. # -set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") -set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") -set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") +if (NOT TRITON_TAG) + set(TRITON_TAG main) +endif() + +set(TRITON_COMMON_REPO_TAG ${TRITON_TAG} CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG ${TRITON_TAG} CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG ${TRITON_TAG} CACHE STRING "Tag for triton-inference-server/backend repo") # @@ -73,6 +77,7 @@ add_library(triton-mmdeploy-backend SHARED model_state.cpp instance_state.cpp convert.cpp + json_input.cpp mmdeploy.cpp) target_include_directories(triton-mmdeploy-backend PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -98,7 +103,7 @@ target_link_libraries(triton-mmdeploy-backend PRIVATE MMDeployLibs) set_target_properties(triton-mmdeploy-backend PROPERTIES INSTALL_RPATH "\$ORIGIN") -install(TARGETS triton-mmdeploy-backend DESTINATION backend/mmdeploy) +install(TARGETS triton-mmdeploy-backend DESTINATION backends/mmdeploy) if (WIN32) set_target_properties( From b901112a44b50042a7bd572f4c55024e34629f48 Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 16 May 2023 11:31:20 +0800 Subject: [PATCH 11/16] update dockerfile --- docker/triton/Dockerfile | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 docker/triton/Dockerfile diff --git a/docker/triton/Dockerfile b/docker/triton/Dockerfile new file mode 100644 index 0000000000..a1e9a16adb --- /dev/null +++ b/docker/triton/Dockerfile @@ -0,0 +1,78 @@ +FROM nvcr.io/nvidia/tritonserver:22.12-pyt-python-py3 + +ARG CUDA=11.3 +ARG TORCH_VERSION=1.10.0 +ARG TORCHVISION_VERSION=0.11.0 +ARG ONNXRUNTIME_VERSION=1.8.1 +ARG PPLCV_VERSION=0.7.0 +ENV FORCE_CUDA="1" +ARG MMCV_VERSION=">=2.0.0rc2" +ARG MMENGINE_VERSION=">=0.3.0" + +WORKDIR /root/workspace + +RUN wget https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.sh &&\ + bash cmake-3.26.3-linux-x86_64.sh --skip-license --prefix=/usr + +RUN git clone --depth 1 --branch v${PPLCV_VERSION} https://github.com/openppl-public/ppl.cv.git &&\ + cd ppl.cv &&\ + ./build.sh cuda &&\ + mv cuda-build/install ./ &&\ + rm -rf cuda-build +ENV pplcv_DIR=/root/workspace/ppl.cv/install/lib/cmake/ppl + +RUN apt-get update &&\ + apt-get install -y libopencv-dev + +RUN wget https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}.tgz \ + && tar -zxvf onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}.tgz +ENV ONNXRUNTIME_DIR=/root/workspace/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION} +ENV LD_LIBRARY_PATH=/root/workspace/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}/lib:$LD_LIBRARY_PATH + +RUN python3 -m pip install -U pip &&\ + pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html &&\ + pip install openmim &&\ + mim install "mmcv"${MMCV_VERSION} onnxruntime-gpu==${ONNXRUNTIME_VERSION} mmengine${MMENGINE_VERSION} &&\ + ln /usr/bin/python3 /usr/bin/python + +COPY TensorRT-8.2.3.0 /root/workspace/tensorrt +RUN pip install /root/workspace/tensorrt/python/*cp38*whl +ENV TENSORRT_DIR=/root/workspace/tensorrt +ENV LD_LIBRARY_PATH=/root/workspace/tensorrt/lib:$LD_LIBRARY_PATH + +RUN apt-get install -y rapidjson-dev + +RUN git clone -b v1.0.0rc7 https://github.com/open-mmlab/mmpretrain.git &&\ + cd mmpretrain && pip install . + +RUN git clone -b v3.0.0 https://github.com/open-mmlab/mmdetection.git &&\ + cd mmdetection && pip install . + +RUN git clone -b v1.0.0 https://github.com/open-mmlab/mmsegmentation.git &&\ + cd mmsegmentation && pip install . + +RUN git clone -b v1.0.0 https://github.com/open-mmlab/mmocr.git &&\ + cd mmocr && pip install . + +RUN git clone -b v1.0.0rc1 https://github.com/open-mmlab/mmrotate.git &&\ + cd mmrotate && pip install . + +RUN git clone -b v1.0.0 https://github.com/open-mmlab/mmpose.git &&\ + cd mmpose && pip install . + +RUN git clone -b triton-server --recursive https://github.com/irexyc/mmdeploy &&\ + cd mmdeploy && mkdir -p build && cd build &&\ + cmake .. \ + -DMMDEPLOY_BUILD_SDK=ON \ + -DMMDEPLOY_TARGET_DEVICES="cuda;cpu" \ + -DMMDEPLOY_BUILD_TEST=OFF \ + -DMMDEPLOY_TARGET_BACKENDS="trt;ort" \ + -DMMDEPLOY_CODEBASES=all \ + -Dpplcv_DIR=${pplcv_DIR} \ + -DMMDEPLOY_BUILD_EXAMPLES=OFF \ + -DMMDEPLOY_DYNAMIC_BACKEND=OFF \ + -DTRITON_MMDEPLOY_BACKEND=ON \ + -DTRITON_TAG="r22.12" &&\ + make -j$(nproc) && make install &&\ + cp -r install/backends /opt/tritonserver/ &&\ + cd .. && pip install -e . --user From bc9db633f56b545ca677aed6214792453d110b52 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 17 May 2023 16:05:52 +0800 Subject: [PATCH 12/16] add triton demo readme --- demo/triton/image-classification/README.md | 41 ++++++++++++++- demo/triton/image-classification/README_CN.md | 0 demo/triton/instance-segmentation/README.md | 41 ++++++++++++++- .../instance-segmentation/README_zh-CN.md | 0 demo/triton/keypoint-detection/README.md | 39 +++++++++++++- demo/triton/keypoint-detection/README_CN.md | 0 demo/triton/object-detection/README.md | 40 ++++++++++++++- demo/triton/object-detection/README_zh-CN.md | 0 .../oriented-object-detection/README.md | 40 ++++++++++++++- .../oriented-object-detection/README_zh-CN.md | 0 demo/triton/semantic-segmentation/README.md | 40 ++++++++++++++- .../triton/semantic-segmentation/README_CN.md | 0 demo/triton/text-detection/README.md | 39 ++++++++++++-- demo/triton/text-detection/README_CN.md | 0 demo/triton/text-ocr/README.md | 51 +++++++++++++++++++ demo/triton/text-recognition/README.md | 40 ++++++++++++++- demo/triton/text-recognition/README_CN.md | 0 demo/triton/to_triton_model.py | 18 ------- 18 files changed, 360 insertions(+), 29 deletions(-) delete mode 100644 demo/triton/image-classification/README_CN.md delete mode 100644 demo/triton/instance-segmentation/README_zh-CN.md delete mode 100644 demo/triton/keypoint-detection/README_CN.md delete mode 100644 demo/triton/object-detection/README_zh-CN.md delete mode 100644 demo/triton/oriented-object-detection/README_zh-CN.md delete mode 100644 demo/triton/semantic-segmentation/README_CN.md delete mode 100644 demo/triton/text-detection/README_CN.md create mode 100644 demo/triton/text-ocr/README.md delete mode 100644 demo/triton/text-recognition/README_CN.md diff --git a/demo/triton/image-classification/README.md b/demo/triton/image-classification/README.md index f881368b4a..ab02023d65 100644 --- a/demo/triton/image-classification/README.md +++ b/demo/triton/image-classification/README.md @@ -1 +1,40 @@ -python tools/deploy.py configs/mmpretrain/classification_tensorrt_static-224x224.py ../mmpretrain/configs/resnet/resnet18_8xb32_in1k.py ../checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth ../mmclassification/demo/demo.JPEG --device cuda --work-dir work_dirs/resnet --dump-info \ No newline at end of file +# Image classification serving + + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmpretrain/classification_tensorrt_static-224x224.py \ + ../mmpretrain/configs/resnet/resnet18_8xb32_in1k.py \ + https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth \ + ../mmpretrain/demo/demo.JPEG \ + --device cuda \ + --work-dir work_dir/resnet \ + --dump-info +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/resnet \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/image-classification/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/image-classification/README_CN.md b/demo/triton/image-classification/README_CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/instance-segmentation/README.md b/demo/triton/instance-segmentation/README.md index db199cbfe4..755e2bdaa3 100644 --- a/demo/triton/instance-segmentation/README.md +++ b/demo/triton/instance-segmentation/README.md @@ -1 +1,40 @@ -python tools/deploy.py configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py ../mmdetection/configs/mask_rcnn/mask-rcnn_r50_fpn_2x_coco.py https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_20200505_003907-3e542a40.pth ../mmdetection/demo/demo.jpg --work-dir work_dir/maskrcnn --dump-info --device cuda \ No newline at end of file +# Instance segmentation serving + + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py \ + ../mmdetection/configs/mask_rcnn/mask-rcnn_r50_fpn_2x_coco.py \ + https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_20200505_003907-3e542a40.pth \ + ../mmdetection/demo/demo.jpg \ + --work-dir work_dir/maskrcnn \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/maskrcnn \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/instance-segmentation/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/instance-segmentation/README_zh-CN.md b/demo/triton/instance-segmentation/README_zh-CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/keypoint-detection/README.md b/demo/triton/keypoint-detection/README.md index 2f2b9e881f..46b8e5ce5f 100644 --- a/demo/triton/keypoint-detection/README.md +++ b/demo/triton/keypoint-detection/README.md @@ -1,4 +1,39 @@ -python tools/deploy.py configs/mmpose/pose-detection_tensorrt_static-256x192.py ../mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py ../checkpoints/td-hm_hrnet-w32_8xb64-210e_coco-256x192-81c58e40_20220909.pth demo/resources/human-pose.jpg --work-dir work_dir/hrnet --dump-info --device cuda +# Keypoint detection serving -python tools/deploy.py configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py ../mmpose/configs/body_2d_keypoint/simcc/coco/simcc_res50_8xb64-210e_coco-256x192.py https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/simcc/coco/simcc_res50_8xb64-210e_coco-256x192-8e0f5b59_20220919.pth demo/resources/human-pose.jpg --work-dir work_dir/pose2 --dump-info --device cuda +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmpose/pose-detection_tensorrt_static-256x192.py \ + ../mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py \ + https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192-81c58e40_20220909.pth \ + demo/resources/human-pose.jpg \ + --work-dir work_dir/hrnet \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/hrnet \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/keypoint-detection/grpc_client.py \ + model \ + /path/to/image +``` \ No newline at end of file diff --git a/demo/triton/keypoint-detection/README_CN.md b/demo/triton/keypoint-detection/README_CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/object-detection/README.md b/demo/triton/object-detection/README.md index 1572845fde..57270719af 100644 --- a/demo/triton/object-detection/README.md +++ b/demo/triton/object-detection/README.md @@ -1 +1,39 @@ -python tools/deploy.py configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py ../mmdetection/configs/retinanet/retinanet_r18_fpn_1x_coco.py https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth ../mmdetection/demo/demo.jpg --work-dir work_dir/retinanet --dump-info --device cuda \ No newline at end of file +# Object detection serving + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \ + ../mmdetection/configs/retinanet/retinanet_r18_fpn_1x_coco.py \ + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth \ + ../mmdetection/demo/demo.jpg \ + --work-dir work_dir/retinanet \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/retinanet \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/object-detection/grpc_client.py \ + model \ + /path/to/image +``` \ No newline at end of file diff --git a/demo/triton/object-detection/README_zh-CN.md b/demo/triton/object-detection/README_zh-CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/oriented-object-detection/README.md b/demo/triton/oriented-object-detection/README.md index 5a055b2be9..7da9a45054 100644 --- a/demo/triton/oriented-object-detection/README.md +++ b/demo/triton/oriented-object-detection/README.md @@ -1 +1,39 @@ -python tools/deploy.py configs/mmrotate/rotated-detection_tensorrt_dynamic-320x320-1024x1024.py ../mmrotate/configs/rotated_faster_rcnn/rotated-faster-rcnn-le90_r50_fpn_1x_dota.py ../../mmrotate/checkpoint/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth ../mmrotate/demo/demo.jpg --dump-info --work-dir work_dir/rrcnn --device cuda \ No newline at end of file +# Oriented object detection serving + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmrotate/rotated-detection_tensorrt_dynamic-320x320-1024x1024.py \ + ../mmrotate/configs/rotated_faster_rcnn/rotated-faster-rcnn-le90_r50_fpn_1x_dota.py \ + https://download.openmmlab.com/mmrotate/v0.1.0/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth \ + ../mmrotate/demo/demo.jpg \ + --dump-info \ + --work-dir work_dir/rrcnn \ + --device cuda +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/rrcnn \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/oriented-object-detection/grpc_client.py \ + model \ + /path/to/image +``` \ No newline at end of file diff --git a/demo/triton/oriented-object-detection/README_zh-CN.md b/demo/triton/oriented-object-detection/README_zh-CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/semantic-segmentation/README.md b/demo/triton/semantic-segmentation/README.md index 12c3275d29..19af161ef8 100644 --- a/demo/triton/semantic-segmentation/README.md +++ b/demo/triton/semantic-segmentation/README.md @@ -1 +1,39 @@ -python tools/deploy.py configs/mmseg/segmentation_tensorrt-fp16_static-512x1024.py ../mmsegmentation/configs/pspnet/pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py ../../checkpoints/pspnet_r18-d8_512x1024_80k_cityscapes_20201225_021458-09ffa746.pth ../mmsegmentation/demo/demo.png --work-dir work_dir/pspnet --dump-info --device cuda \ No newline at end of file +# Semantic segmentation serving + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmseg/segmentation_tensorrt-fp16_static-512x1024.py \ + ../mmsegmentation/configs/pspnet/pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py \ + https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_512x1024_80k_cityscapes/pspnet_r18-d8_512x1024_80k_cityscapes_20201225_021458-09ffa746.pth \ + ../mmsegmentation/demo/demo.png \ + --work-dir work_dir/pspnet \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/pspnet \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/semantic-segmentation/grpc_client.py \ + model \ + /path/to/image +``` \ No newline at end of file diff --git a/demo/triton/semantic-segmentation/README_CN.md b/demo/triton/semantic-segmentation/README_CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/text-detection/README.md b/demo/triton/text-detection/README.md index 7ecb52c498..7c3880e4e0 100644 --- a/demo/triton/text-detection/README.md +++ b/demo/triton/text-detection/README.md @@ -1,6 +1,39 @@ -python tools/deploy.py configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py ../mmocr/configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py https://download.openmmlab.com/mmocr/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth ../mmocr/demo/demo_text_det.jpg --work-dir work_dir/panet --dump-info --device cuda:0 +# Text detection serving -python tools/deploy.py configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py ../mmocr/configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015/dbnet_resnet18_fpnc_1200e_icdar2015_20220825_221614-7c0e94f2.pth ../mmocr/demo/demo_text_det.jpg --work-dir work_dir/dbnet --dump-info --device cuda:0 +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py \ + ../mmocr/configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py \ + https://download.openmmlab.com/mmocr/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth \ + ../mmocr/demo/demo_text_det.jpg \ + --work-dir work_dir/panet \ + --dump-info \ + --device cuda:0 +``` -python tools/deploy.py configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py ../mmocr/configs/textdet/psenet/psenet_resnet50_fpnf_600e_icdar2015.py https://download.openmmlab.com/mmocr/textdet/psenet/psenet_resnet50_fpnf_600e_icdar2015/psenet_resnet50_fpnf_600e_icdar2015_20220825_222709-b6741ec3.pth ../mmocr/demo/demo_text_det.jpg --work-dir work_dir/psenet --dump-info --device cuda:0 \ No newline at end of file +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/panet \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/text-detection/grpc_client.py \ + model \ + /path/to/image +``` \ No newline at end of file diff --git a/demo/triton/text-detection/README_CN.md b/demo/triton/text-detection/README_CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/text-ocr/README.md b/demo/triton/text-ocr/README.md new file mode 100644 index 0000000000..66cb7e4373 --- /dev/null +++ b/demo/triton/text-ocr/README.md @@ -0,0 +1,51 @@ +# Text ocr serving + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy + +# text-detection +python3 tools/deploy.py \ + configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py \ + ../mmocr/configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py \ + https://download.openmmlab.com/mmocr/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth \ + ../mmocr/demo/demo_text_det.jpg \ + --work-dir work_dir/panet \ + --dump-info \ + --device cuda:0 + +# text-recognition +python3 tools/deploy.py \ + configs/mmocr/text-recognition/text-recognition_tensorrt-fp16_dynamic-1x32x32-1x32x640.py \ + ../mmocr/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py \ + https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth \ + ../mmocr/demo/demo_text_recog.jpg \ + --work-dir work_dir/crnn \ + --device cuda \ + --dump-info +``` + +## Ensemble detection and recognition model +``` +cd /root/workspace/mmdeploy +cp -r demo/triton/text-ocr/serving /model-repository +cp -r work_dir/panet/* /model-repository/model/1/text_detection/ +cp -r work_dir/crnn/* /model-repository/model/1/text_recognition/ +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/text-ocr/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/text-recognition/README.md b/demo/triton/text-recognition/README.md index 9eaa1df85a..5b4c428559 100644 --- a/demo/triton/text-recognition/README.md +++ b/demo/triton/text-recognition/README.md @@ -1 +1,39 @@ -python tools/deploy.py configs/mmocr/text-recognition/text-recognition_tensorrt-fp16_dynamic-1x32x32-1x32x640.py ../mmocr/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth ../mmocr/demo/demo_text_recog.jpg --work-dir work_dir/crnn --device cuda --dump-info \ No newline at end of file +# Text recognition serving + +## Starting a docker container +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmocr/text-recognition/text-recognition_tensorrt-fp16_dynamic-1x32x32-1x32x640.py \ + ../mmocr/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py \ + https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth \ + ../mmocr/demo/demo_text_recog.jpg \ + --work-dir work_dir/crnn \ + --device cuda \ + --dump-info +``` + +## Convert tensorrt model to triton format +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/crnn \ + /model-repository +``` + +## Start triton server +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container +``` +python3 demo/triton/text-detection/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/text-recognition/README_CN.md b/demo/triton/text-recognition/README_CN.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/demo/triton/to_triton_model.py b/demo/triton/to_triton_model.py index cdeb2d9ff6..ee3f95891a 100644 --- a/demo/triton/to_triton_model.py +++ b/demo/triton/to_triton_model.py @@ -158,21 +158,3 @@ def get_model_type(detail_cfg, pipeline_cfg): os.makedirs(output_dir) copy_template(src_folder, output_dir) convert.create_single_model() - - -# /data/testmodel/mmcls - -# /home/cx/ws/2.0/mmdeploy/work_dir/rtn2/ -# /home/cx/ws/2.0/mmdeploy/work_dir/maskrcnn/ - -# /home/cx/ws/2.0/mmdeploy/work_dir/pose1 -# /home/cx/ws/2.0/mmdeploy/work_dir/pose2 - -# /home/cx/ws/2.0/mmdeploy/work_dir/pspnet -# /home/cx/ws/2.0/mmdeploy/work_dir/pspnet_mask - -# /home/cx/ws/2.0/mmdeploy/work_dirs/rrcnn - -# /home/cx/ws/2.0/mmdeploy/work_dir/panet - -# /home/cx/ws/2.0/mmdeploy/work_dir/crnn From 121d9bac2b7dfd6647a7801ea37ba4c219d05936 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 18 May 2023 14:02:20 +0800 Subject: [PATCH 13/16] update docs --- docs/en/02-how-to-run/triton_server.md | 38 +++++++++++++++++++++++ docs/en/get_started.md | 4 +++ docs/en/index.rst | 1 + docs/zh_cn/02-how-to-run/triton_server.md | 38 +++++++++++++++++++++++ docs/zh_cn/get_started.md | 4 +++ docs/zh_cn/index.rst | 1 + 6 files changed, 86 insertions(+) create mode 100644 docs/en/02-how-to-run/triton_server.md create mode 100644 docs/zh_cn/02-how-to-run/triton_server.md diff --git a/docs/en/02-how-to-run/triton_server.md b/docs/en/02-how-to-run/triton_server.md new file mode 100644 index 0000000000..7652d47755 --- /dev/null +++ b/docs/en/02-how-to-run/triton_server.md @@ -0,0 +1,38 @@ +# Model serving + +MMDeploy provides model server deployment based on Triton Inference Server. + +## Supported tasks + +The following tasks are currently supported: + +- [image-classification](../../../demo/triton/image-classification/README.md) +- [instance-segmentation](../../../demo/triton/instance-segmentation) +- [keypoint-detection](../../../demo/triton/keypoint-detection) +- [object-detection](../../../demo/triton/object-detection) +- [oriented-object-detection](../../../demo/triton/oriented-object-detection) +- [semantic-segmentation](../../../demo/triton/semantic-segmentation) +- [text-detection](../../../demo/triton/text-detection) +- [text-recognition](../../../demo/triton/text-recognition) +- [text-ocr](../../../demo/triton/text-ocr) + +## Run Triton + +In order to use Triton Inference Server, we need: + +1. Compile MMDeploy Triton Backend +2. Prepare the model repository (including model files, and configuration files) + +### Compile MMDeploy Triton Backend + +a) Using Docker images + +For ease of use, we provide a Docker image to support the deployment of models converted by MMDeploy. The image supports Tensorrt and ONNX Runtime as backends. If you need other backends, you can choose build from source. + +b) Build from source + +You can refer [build from source](../01-how-to-build/build_from_source.md) to build MMDeploy. In order to build MMDeploy Triton Backend, you need to add `-DTRITON_MMDEPLOY_BACKEND=ON` to cmake configure command. By default, the latest version of Triton Backend is used. If you want to use an older version of Triton Backend, you can add `-DTRITON_TAG=r22.12` to the cmake configure command. + +### Prepare the model repository + +Triton Inference Server has its own model description rules. Therefore the models converted through `tools/deploy.py ... --dump-info` need to be formatted to make Triton load correctly. We have prepared templates for each task. You can use `demo/triton/to_triton_model.py` script for model formatting. For complete samples, please refer to the description of each demo. diff --git a/docs/en/get_started.md b/docs/en/get_started.md index 9fce872ea1..f3afdb1080 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -330,6 +330,10 @@ We'll talk about them more in our next release. If you want to fuse preprocess for acceleration,please refer to this [doc](./02-how-to-run/fuse_transform.md) +## Model serving (triton) + +For server-side deployment, please read [model serving](02-how-to-run/triton_server.md) for more details. + ## Evaluate Model You can test the performance of deployed model using `tool/test.py`. For example, diff --git a/docs/en/index.rst b/docs/en/index.rst index 0704aeaf8f..1833ef1d0c 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -27,6 +27,7 @@ You can switch between Chinese and English documents in the lower-left corner of 02-how-to-run/profile_model.md 02-how-to-run/quantize_model.md 02-how-to-run/useful_tools.md + 02-how-to-run/triton_server.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/02-how-to-run/triton_server.md b/docs/zh_cn/02-how-to-run/triton_server.md new file mode 100644 index 0000000000..d9097252ca --- /dev/null +++ b/docs/zh_cn/02-how-to-run/triton_server.md @@ -0,0 +1,38 @@ +# 如何进行服务端部署 + +模型转换后,MMDeploy 提供基于 Triton Inference Server 的模型服务端部署。 + +## 支持的任务 + +目前支持以下任务: + +- [image-classification](../../../demo/triton/image-classification/README.md) +- [instance-segmentation](../../../demo/triton/instance-segmentation) +- [keypoint-detection](../../../demo/triton/keypoint-detection) +- [object-detection](../../../demo/triton/object-detection) +- [oriented-object-detection](../../../demo/triton/oriented-object-detection) +- [semantic-segmentation](../../../demo/triton/semantic-segmentation) +- [text-detection](../../../demo/triton/text-detection) +- [text-recognition](../../../demo/triton/text-recognition) +- [text-ocr](../../../demo/triton/text-ocr) + +## 如何部署 Triton 服务 + +为了使用 Triton Inference Server, 我们需要: + +1. 编译 MMDeploy Triton Backend +2. 准备模型库(包括模型文件,以及配置文件) + +### 编译 MMDeploy Triton Backend + +a) 使用 Docker 镜像 + +为了方便使用,我们提供了 Docker 镜像,支持对通过 MMDeploy 转换的模型进行部署。镜像支持 Tensorrt 以及 ONNX Runtime 作为后端。若需要其他后端,可选择从源码进行编译。 + +b) 从源码编译 + +从源码编译 MMDeploy 的方式可参考[源码手动安装](../01-how-to-build/build_from_source.md),要编译 MMDeploy Triton Backend,需要在编译命令中添加:`-DTRITON_MMDEPLOY_BACKEND=ON`。默认使用最新版本的 Triton Backend,若要使用旧版本的 Triton Backend,可在编译命令中添加`-DTRITON_TAG=r22.12` + +### 准备模型库 + +Triton Inference Server 有一套自己的模型描述规则,通过 `tools/deploy.py ... --dump-info ` 转换的模型需要调整格式才能使 Triton 正确加载,我们为各任务准备了模版,可以运行 `demo/triton/to_triton_model.py` 转换脚本格式进行修改。完整的样例可参考各个 demo 的说明。 diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index 27b4e55245..ed85e52d25 100644 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -331,6 +331,10 @@ target_link_libraries(${name} PRIVATE mmdeploy ${OpenCV_LIBS}) 若要对预处理进行加速,请查阅[此处](./02-how-to-run/fuse_transform.md) +## 服务端部署 (triton) + +若需要进行服务端部署,请阅读 [服务端部署](02-how-to-run/triton_server.md) 了解更多细节 + ## 模型精度评估 为了测试部署模型的精度,推理效率,我们提供了 `tools/test.py` 来帮助完成相关工作。以上文中的部署模型为例: diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index e52a40c7aa..12959be2aa 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -27,6 +27,7 @@ 02-how-to-run/profile_model.md 02-how-to-run/quantize_model.md 02-how-to-run/useful_tools.md + 02-how-to-run/triton_server.md .. toctree:: :maxdepth: 1 From 9aed1bfbd680b3d05816cb54c9399a65e9653573 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 18 May 2023 14:07:36 +0800 Subject: [PATCH 14/16] fix lint --- csrc/mmdeploy/triton/instance_state.cpp | 2 +- demo/triton/image-classification/README.md | 6 ++- .../image-classification/grpc_client.py | 27 ++++++----- .../serving/model/1/README.md | 2 +- demo/triton/instance-segmentation/README.md | 6 ++- .../instance-segmentation/grpc_client.py | 29 +++++------ .../serving/model/1/README.md | 2 +- .../serving/model/config.pbtxt | 2 +- demo/triton/keypoint-detection/README.md | 7 ++- demo/triton/keypoint-detection/grpc_client.py | 42 ++++++++-------- .../serving/model/1/README.md | 2 +- .../serving/model/config.pbtxt | 2 +- demo/triton/object-detection/README.md | 7 ++- demo/triton/object-detection/grpc_client.py | 27 ++++++----- .../serving/model/1/README.md | 2 +- .../oriented-object-detection/README.md | 7 ++- .../oriented-object-detection/grpc_client.py | 29 +++++------ .../serving/model/1/README.md | 2 +- demo/triton/semantic-segmentation/README.md | 7 ++- .../semantic-segmentation/grpc_client.py | 27 ++++++----- .../serving/mask/model/1/README.md | 2 +- .../serving/score/model/1/README.md | 2 +- .../serving/score/model/config.pbtxt | 2 +- demo/triton/text-detection/README.md | 7 ++- demo/triton/text-detection/grpc_client.py | 29 ++++++----- .../text-detection/serving/model/1/README.md | 2 +- demo/triton/text-ocr/README.md | 5 ++ demo/triton/text-ocr/grpc_client.py | 32 +++++++------ .../serving/model/1/pipeline_template.json | 2 +- demo/triton/text-recognition/README.md | 5 ++ demo/triton/text-recognition/grpc_client.py | 48 ++++++++++--------- .../serving/model/1/README.md | 2 +- demo/triton/to_triton_model.py | 35 +++++++------- 33 files changed, 231 insertions(+), 179 deletions(-) diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp index 65a0f1af97..fe8808cdfb 100644 --- a/csrc/mmdeploy/triton/instance_state.cpp +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -247,7 +247,7 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests } } - // merget inputs for example: [[a,a,a], [b,b,b], [c,c,c]] -> [[aaa], [(b,c), (b,c), (b,c)]] + // merge inputs for example: [[a,a,a], [b,b,b], [c,c,c]] -> [[aaa], [(b,c), (b,c), (b,c)]] if (!merge_inputs_.empty()) { int n_example = vec_inputs[0].size(); ::mmdeploy::Value inputs; diff --git a/demo/triton/image-classification/README.md b/demo/triton/image-classification/README.md index ab02023d65..e7cb270009 100644 --- a/demo/triton/image-classification/README.md +++ b/demo/triton/image-classification/README.md @@ -1,12 +1,13 @@ # Image classification serving - ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -20,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -28,11 +30,13 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/image-classification/grpc_client.py \ model \ diff --git a/demo/triton/image-classification/grpc_client.py b/demo/triton/image-classification/grpc_client.py index 51a876fb4b..92755daf7b 100644 --- a/demo/triton/image-classification/grpc_client.py +++ b/demo/triton/image-classification/grpc_client.py @@ -1,15 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -22,14 +22,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -43,8 +45,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -65,11 +66,11 @@ def visualize(results): print(f'label {labels[i]} score {scores[i]}') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/image-classification/serving/model/1/README.md b/demo/triton/image-classification/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/image-classification/serving/model/1/README.md +++ b/demo/triton/image-classification/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/instance-segmentation/README.md b/demo/triton/instance-segmentation/README.md index 755e2bdaa3..fabceeca4d 100644 --- a/demo/triton/instance-segmentation/README.md +++ b/demo/triton/instance-segmentation/README.md @@ -1,12 +1,13 @@ # Instance segmentation serving - ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -20,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -28,11 +30,13 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/instance-segmentation/grpc_client.py \ model \ diff --git a/demo/triton/instance-segmentation/grpc_client.py b/demo/triton/instance-segmentation/grpc_client.py index 23b33d7d43..6d9473ff0d 100644 --- a/demo/triton/instance-segmentation/grpc_client.py +++ b/demo/triton/instance-segmentation/grpc_client.py @@ -1,16 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse -import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput import math +import cv2 +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -23,14 +23,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -44,8 +46,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -83,11 +84,11 @@ def visualize(img, results): cv2.imwrite('instance-segmentation.jpg', img) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/instance-segmentation/serving/model/1/README.md b/demo/triton/instance-segmentation/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/instance-segmentation/serving/model/1/README.md +++ b/demo/triton/instance-segmentation/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/instance-segmentation/serving/model/config.pbtxt b/demo/triton/instance-segmentation/serving/model/config.pbtxt index 4ec61a589d..8e8d145bdd 100644 --- a/demo/triton/instance-segmentation/serving/model/config.pbtxt +++ b/demo/triton/instance-segmentation/serving/model/config.pbtxt @@ -27,4 +27,4 @@ output { name: "mask_offs" data_type: TYPE_INT32 dims: [ -1, 3 ] -} \ No newline at end of file +} diff --git a/demo/triton/keypoint-detection/README.md b/demo/triton/keypoint-detection/README.md index 46b8e5ce5f..5839ac10bc 100644 --- a/demo/triton/keypoint-detection/README.md +++ b/demo/triton/keypoint-detection/README.md @@ -1,11 +1,13 @@ # Keypoint detection serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -19,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -27,13 +30,15 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/keypoint-detection/grpc_client.py \ model \ /path/to/image -``` \ No newline at end of file +``` diff --git a/demo/triton/keypoint-detection/grpc_client.py b/demo/triton/keypoint-detection/grpc_client.py index 9a3ee02590..191ee54a91 100644 --- a/demo/triton/keypoint-detection/grpc_client.py +++ b/demo/triton/keypoint-detection/grpc_client.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import json + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput import numpy as np -import json +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -24,14 +24,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -46,10 +48,10 @@ def infer(self, image, box): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8"), - InferInput(self._input_names[1], box.shape, - "BYTES")] + inputs = [ + InferInput(self._input_names[0], image.shape, 'UINT8'), + InferInput(self._input_names[1], box.shape, 'BYTES') + ] inputs[0].set_data_from_numpy(image) inputs[1].set_data_from_numpy(box) results = self._client.infer( @@ -72,20 +74,18 @@ def visualize(img, results): cv2.imwrite('keypoint-detection.jpg', img) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) bbox = { 'type': 'PoseBbox', - 'value': [ - { - 'bbox': [0.0, 0.0, img.shape[1], img.shape[0]] - } - ] + 'value': [{ + 'bbox': [0.0, 0.0, img.shape[1], img.shape[0]] + }] } bbox = np.array([json.dumps(bbox).encode('utf-8')]) results = client.infer(img, bbox) diff --git a/demo/triton/keypoint-detection/serving/model/1/README.md b/demo/triton/keypoint-detection/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/keypoint-detection/serving/model/1/README.md +++ b/demo/triton/keypoint-detection/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/keypoint-detection/serving/model/config.pbtxt b/demo/triton/keypoint-detection/serving/model/config.pbtxt index fa6380c421..1877c3b0fa 100644 --- a/demo/triton/keypoint-detection/serving/model/config.pbtxt +++ b/demo/triton/keypoint-detection/serving/model/config.pbtxt @@ -26,4 +26,4 @@ parameters { value: { string_value: "0 1" } -} \ No newline at end of file +} diff --git a/demo/triton/object-detection/README.md b/demo/triton/object-detection/README.md index 57270719af..eff432c6e3 100644 --- a/demo/triton/object-detection/README.md +++ b/demo/triton/object-detection/README.md @@ -1,11 +1,13 @@ # Object detection serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -19,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -27,13 +30,15 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/object-detection/grpc_client.py \ model \ /path/to/image -``` \ No newline at end of file +``` diff --git a/demo/triton/object-detection/grpc_client.py b/demo/triton/object-detection/grpc_client.py index 3795a84e90..321cabfcac 100644 --- a/demo/triton/object-detection/grpc_client.py +++ b/demo/triton/object-detection/grpc_client.py @@ -1,15 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -22,14 +22,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -43,8 +45,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -68,11 +69,11 @@ def visualize(img, results): cv2.imwrite('object-detection.jpg', img) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/object-detection/serving/model/1/README.md b/demo/triton/object-detection/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/object-detection/serving/model/1/README.md +++ b/demo/triton/object-detection/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/oriented-object-detection/README.md b/demo/triton/oriented-object-detection/README.md index 7da9a45054..670c53c483 100644 --- a/demo/triton/oriented-object-detection/README.md +++ b/demo/triton/oriented-object-detection/README.md @@ -1,11 +1,13 @@ # Oriented object detection serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -19,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -27,13 +30,15 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/oriented-object-detection/grpc_client.py \ model \ /path/to/image -``` \ No newline at end of file +``` diff --git a/demo/triton/oriented-object-detection/grpc_client.py b/demo/triton/oriented-object-detection/grpc_client.py index 299944f944..9e0525a068 100644 --- a/demo/triton/oriented-object-detection/grpc_client.py +++ b/demo/triton/oriented-object-detection/grpc_client.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +from math import cos, sin + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput import numpy as np -from math import cos, sin +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -24,14 +24,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -45,8 +47,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -79,11 +80,11 @@ def visualize(img, results): cv2.imwrite('oriented-object-detection.jpg', img) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/oriented-object-detection/serving/model/1/README.md b/demo/triton/oriented-object-detection/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/oriented-object-detection/serving/model/1/README.md +++ b/demo/triton/oriented-object-detection/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/semantic-segmentation/README.md b/demo/triton/semantic-segmentation/README.md index 19af161ef8..31e53a504a 100644 --- a/demo/triton/semantic-segmentation/README.md +++ b/demo/triton/semantic-segmentation/README.md @@ -1,11 +1,13 @@ # Semantic segmentation serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -19,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -27,13 +30,15 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/semantic-segmentation/grpc_client.py \ model \ /path/to/image -``` \ No newline at end of file +``` diff --git a/demo/triton/semantic-segmentation/grpc_client.py b/demo/triton/semantic-segmentation/grpc_client.py index 72e967a230..296723b034 100644 --- a/demo/triton/semantic-segmentation/grpc_client.py +++ b/demo/triton/semantic-segmentation/grpc_client.py @@ -1,16 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput import numpy as np +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -32,14 +32,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -53,8 +55,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -84,11 +85,11 @@ def visualize(img, results): cv2.imwrite('semantic-segmentation.png', img) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/semantic-segmentation/serving/mask/model/1/README.md b/demo/triton/semantic-segmentation/serving/mask/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/semantic-segmentation/serving/mask/model/1/README.md +++ b/demo/triton/semantic-segmentation/serving/mask/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/semantic-segmentation/serving/score/model/1/README.md b/demo/triton/semantic-segmentation/serving/score/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/semantic-segmentation/serving/score/model/1/README.md +++ b/demo/triton/semantic-segmentation/serving/score/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt b/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt index e905be48e4..3fe8eda342 100644 --- a/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt +++ b/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt @@ -12,4 +12,4 @@ output { name: "score" data_type: TYPE_FP32 dims: [ -1, -1, -1 ] -} \ No newline at end of file +} diff --git a/demo/triton/text-detection/README.md b/demo/triton/text-detection/README.md index 7c3880e4e0..ac4074c4f3 100644 --- a/demo/triton/text-detection/README.md +++ b/demo/triton/text-detection/README.md @@ -1,11 +1,13 @@ # Text detection serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -19,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -27,13 +30,15 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/text-detection/grpc_client.py \ model \ /path/to/image -``` \ No newline at end of file +``` diff --git a/demo/triton/text-detection/grpc_client.py b/demo/triton/text-detection/grpc_client.py index 64abd1e128..d93076b7ec 100644 --- a/demo/triton/text-detection/grpc_client.py +++ b/demo/triton/text-detection/grpc_client.py @@ -1,17 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput -import numpy as np -import json +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -24,14 +22,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -45,8 +45,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -71,11 +70,11 @@ def visualize(img, results): cv2.imwrite('text-detection.jpg', img) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/text-detection/serving/model/1/README.md b/demo/triton/text-detection/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/text-detection/serving/model/1/README.md +++ b/demo/triton/text-detection/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/text-ocr/README.md b/demo/triton/text-ocr/README.md index 66cb7e4373..acda4e62e8 100644 --- a/demo/triton/text-ocr/README.md +++ b/demo/triton/text-ocr/README.md @@ -1,11 +1,13 @@ # Text ocr serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy @@ -31,6 +33,7 @@ python3 tools/deploy.py \ ``` ## Ensemble detection and recognition model + ``` cd /root/workspace/mmdeploy cp -r demo/triton/text-ocr/serving /model-repository @@ -39,11 +42,13 @@ cp -r work_dir/crnn/* /model-repository/model/1/text_recognition/ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/text-ocr/grpc_client.py \ model \ diff --git a/demo/triton/text-ocr/grpc_client.py b/demo/triton/text-ocr/grpc_client.py index 28729f10fa..7cf70cad29 100644 --- a/demo/triton/text-ocr/grpc_client.py +++ b/demo/triton/text-ocr/grpc_client.py @@ -1,15 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -22,14 +22,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -43,8 +45,7 @@ def infer(self, image): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8")] + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] inputs[0].set_data_from_numpy(image) results = self._client.infer( model_name=self._model_name, @@ -63,17 +64,18 @@ def visualize(results): for i, (det_bbox, det_score, rec_text, rec_score) in \ enumerate(zip(det_bboxes, det_scores, rec_texts, rec_scores)): print(f'bbox[{i}] ({det_bbox[0]:.2f}, {det_bbox[1]:.2f}), ' - f'({det_bbox[2]:.2f}, {det_bbox[3]:.2f}), ({det_bbox[4]:.2f}, {det_bbox[5]:.2f}), ' - f'({det_bbox[6]:.2f}, {det_bbox[7]:.2f}), {det_score[0]:.2f}') + f'({det_bbox[2]:.2f}, {det_bbox[3]:.2f}), ({det_bbox[4]:.2f}, ' + f'{det_bbox[5]:.2f}), ({det_bbox[6]:.2f}, {det_bbox[7]:.2f}), ' + f'{det_score[0]:.2f}') text = rec_text.decode('utf-8') print(f'text[{i}] {text}') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) results = client.infer(img) diff --git a/demo/triton/text-ocr/serving/model/1/pipeline_template.json b/demo/triton/text-ocr/serving/model/1/pipeline_template.json index f841e83c9f..37ab975475 100644 --- a/demo/triton/text-ocr/serving/model/1/pipeline_template.json +++ b/demo/triton/text-ocr/serving/model/1/pipeline_template.json @@ -47,4 +47,4 @@ "output": "*texts" } ] -} \ No newline at end of file +} diff --git a/demo/triton/text-recognition/README.md b/demo/triton/text-recognition/README.md index 5b4c428559..38c4b72cad 100644 --- a/demo/triton/text-recognition/README.md +++ b/demo/triton/text-recognition/README.md @@ -1,11 +1,13 @@ # Text recognition serving ## Starting a docker container + ``` docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 ``` ## Convert pytorch model to tensorrt model + ``` cd /root/workspace/mmdeploy python3 tools/deploy.py \ @@ -19,6 +21,7 @@ python3 tools/deploy.py \ ``` ## Convert tensorrt model to triton format + ``` cd /root/workspace/mmdeploy python3 demo/triton/to_triton_model.py \ @@ -27,11 +30,13 @@ python3 demo/triton/to_triton_model.py \ ``` ## Start triton server + ``` tritonserver --model-repository=/model-repository ``` ## Run client code output container + ``` python3 demo/triton/text-detection/grpc_client.py \ model \ diff --git a/demo/triton/text-recognition/grpc_client.py b/demo/triton/text-recognition/grpc_client.py index 8d045e2aae..8e5384b252 100644 --- a/demo/triton/text-recognition/grpc_client.py +++ b/demo/triton/text-recognition/grpc_client.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import json + import cv2 -from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput import numpy as np -import json +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, - help='model name') - parser.add_argument('image', type=str, - help='image path') + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') return parser.parse_args() @@ -24,14 +24,16 @@ def __init__(self, url, model_name, model_version): self._client = InferenceServerClient(self._url) model_config = self._client.get_model_config(self._model_name, self._model_version) - model_metadata = self._client.get_model_metadata(self._model_name, - self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) print(f'[model config]:\n{model_config}') print(f'[model metadata]:\n{model_metadata}') self._inputs = {input.name: input for input in model_metadata.inputs} self._input_names = list(self._inputs) self._outputs = { - output.name: output for output in model_metadata.outputs} + output.name: output + for output in model_metadata.outputs + } self._output_names = list(self._outputs) self._outputs_req = [ InferRequestedOutput(name) for name in self._outputs @@ -46,10 +48,10 @@ def infer(self, image, box): results: dict, {name : numpy.array} """ - inputs = [InferInput(self._input_names[0], image.shape, - "UINT8"), - InferInput(self._input_names[1], box.shape, - "BYTES")] + inputs = [ + InferInput(self._input_names[0], image.shape, 'UINT8'), + InferInput(self._input_names[1], box.shape, 'BYTES') + ] inputs[0].set_data_from_numpy(image) inputs[1].set_data_from_numpy(box) results = self._client.infer( @@ -70,20 +72,22 @@ def visualize(results): print(box_texts, box_scores) -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() model_name = args.model_name - model_version = "1" - url = "localhost:8001" + model_version = '1' + url = 'localhost:8001' client = GRPCTritonClient(url, model_name, model_version) img = cv2.imread(args.image) bbox = { - 'type': 'TextBbox', - 'value': [ - { - 'bbox': [0.0, 0.0, img.shape[1], 0, img.shape[1], img.shape[0], 0, img.shape[0]], - } - ] + 'type': + 'TextBbox', + 'value': [{ + 'bbox': [ + 0.0, 0.0, img.shape[1], 0, img.shape[1], img.shape[0], 0, + img.shape[0] + ], + }] } bbox = np.array([json.dumps(bbox).encode('utf-8')]) results = client.infer(img, bbox) diff --git a/demo/triton/text-recognition/serving/model/1/README.md b/demo/triton/text-recognition/serving/model/1/README.md index ff6d1ae274..3b5ec0b47e 100644 --- a/demo/triton/text-recognition/serving/model/1/README.md +++ b/demo/triton/text-recognition/serving/model/1/README.md @@ -1 +1 @@ -This directory holds the model files. \ No newline at end of file +This directory holds the model files. diff --git a/demo/triton/to_triton_model.py b/demo/triton/to_triton_model.py index ee3f95891a..81cb864ad1 100644 --- a/demo/triton/to_triton_model.py +++ b/demo/triton/to_triton_model.py @@ -3,8 +3,8 @@ import json import os import os.path as osp -from enum import Enum, unique import shutil +from enum import Enum, unique from glob import glob BASEDIR = os.path.dirname(__file__) @@ -35,7 +35,8 @@ def copy_template(src_folder, dst_folder): class Convert: - def __init__(self, model_type, model_dir, deploy_cfg, pipeline_cfg, detail_cfg, output_dir): + def __init__(self, model_type, model_dir, deploy_cfg, pipeline_cfg, + detail_cfg, output_dir): self._model_type = model_type self._model_dir = model_dir self._deploy_cfg = deploy_cfg @@ -62,19 +63,14 @@ def create_single_model(self): self._pipeline_cfg['pipeline']['input'].append('bbox') self._pipeline_cfg['pipeline']['tasks'][0]['input'] = ['patch'] warpbbox = { - "type": "Task", - "module": "WarpBbox", - "input": [ - "img", - "bbox" - ], - "output": [ - "patch" - ] + 'type': 'Task', + 'module': 'WarpBbox', + 'input': ['img', 'bbox'], + 'output': ['patch'] } self._pipeline_cfg['pipeline']['tasks'].insert(0, warpbbox) - self.write_json_file(self._pipeline_cfg, - 'pipeline.json', output_model_folder) + self.write_json_file(self._pipeline_cfg, 'pipeline.json', + output_model_folder) else: self.copy_file('pipeline.json', self._model_dir, output_model_folder) @@ -90,10 +86,12 @@ def create_single_model(self): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('model_dir', type=str, - help='converted model dir with `--dump-info` flag when convert the model') - parser.add_argument('output_dir', type=str, - help='output dir') + parser.add_argument( + 'model_dir', + type=str, + help='converted model dir with ' + '`--dump-info` flag when convert the model') + parser.add_argument('output_dir', type=str, help='output dir') return parser.parse_args() @@ -145,7 +143,8 @@ def get_model_type(detail_cfg, pipeline_cfg): pipeline_cfg = json.load(f) with open(osp.join(model_dir, 'detail.json')) as f: detail_cfg = json.load(f) - assert 'onnx_config' in detail_cfg, f'currently, only support onnx as middle ir' + assert 'onnx_config' in detail_cfg, \ + 'currently, only support onnx as middle ir' # process model_type = get_model_type(detail_cfg, pipeline_cfg) From 6fe281c205af5dea4750921b600115fed42b34b2 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 18 May 2023 14:12:30 +0800 Subject: [PATCH 15/16] update docs --- docs/en/02-how-to-run/triton_server.md | 4 ++++ docs/zh_cn/02-how-to-run/triton_server.md | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/docs/en/02-how-to-run/triton_server.md b/docs/en/02-how-to-run/triton_server.md index 7652d47755..a8fe4b1df1 100644 --- a/docs/en/02-how-to-run/triton_server.md +++ b/docs/en/02-how-to-run/triton_server.md @@ -29,6 +29,10 @@ a) Using Docker images For ease of use, we provide a Docker image to support the deployment of models converted by MMDeploy. The image supports Tensorrt and ONNX Runtime as backends. If you need other backends, you can choose build from source. +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + b) Build from source You can refer [build from source](../01-how-to-build/build_from_source.md) to build MMDeploy. In order to build MMDeploy Triton Backend, you need to add `-DTRITON_MMDEPLOY_BACKEND=ON` to cmake configure command. By default, the latest version of Triton Backend is used. If you want to use an older version of Triton Backend, you can add `-DTRITON_TAG=r22.12` to the cmake configure command. diff --git a/docs/zh_cn/02-how-to-run/triton_server.md b/docs/zh_cn/02-how-to-run/triton_server.md index d9097252ca..bc25d6da03 100644 --- a/docs/zh_cn/02-how-to-run/triton_server.md +++ b/docs/zh_cn/02-how-to-run/triton_server.md @@ -29,6 +29,10 @@ a) 使用 Docker 镜像 为了方便使用,我们提供了 Docker 镜像,支持对通过 MMDeploy 转换的模型进行部署。镜像支持 Tensorrt 以及 ONNX Runtime 作为后端。若需要其他后端,可选择从源码进行编译。 +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + b) 从源码编译 从源码编译 MMDeploy 的方式可参考[源码手动安装](../01-how-to-build/build_from_source.md),要编译 MMDeploy Triton Backend,需要在编译命令中添加:`-DTRITON_MMDEPLOY_BACKEND=ON`。默认使用最新版本的 Triton Backend,若要使用旧版本的 Triton Backend,可在编译命令中添加`-DTRITON_TAG=r22.12` From fcdf52f0a504f86e9d3528dd72856a71ef5e28d9 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 18 May 2023 15:39:13 +0800 Subject: [PATCH 16/16] fix lint --- csrc/mmdeploy/triton/instance_state.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp index fe8808cdfb..ec8a9414a0 100644 --- a/csrc/mmdeploy/triton/instance_state.cpp +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -388,7 +388,7 @@ TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests TritonModelInstance(), total_batch_size, exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), "failed reporting batch request statistics"); -#endif // TRITON_ENABLE_STATS +#endif return nullptr; // success }