From c50c04d0caa90eede9d1af16d6f0d5470a1c89ad Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 16 Sep 2024 18:38:58 -0700 Subject: [PATCH 01/11] Address ZeroK case for Gemm for CPU and CUDA --- cmake/onnxruntime_python.cmake | 4 +- onnxruntime/core/providers/cpu/math/gemm.cc | 59 ++++++++++++------- .../core/providers/cpu/math/gemm_helper.h | 3 +- onnxruntime/core/providers/cuda/math/gemm.cc | 10 ++++ .../test/providers/cpu/math/gemm_test.cc | 41 +++++++++++++ 5 files changed, 91 insertions(+), 26 deletions(-) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 574cffbb716b..ad9324aa3024 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -71,9 +71,7 @@ onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_p if(MSVC) target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>") - if(onnxruntime_ENABLE_TRAINING) - target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj") - endif() + target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj") endif() if(HAS_CAST_FUNCTION_TYPE) target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type") diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5a886cce9d5d..fb39d952cc71 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -154,6 +154,14 @@ void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, // Broadcast the bias as needed if bias is given GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data); + if (K == 0) { + if (beta == 0 || c_data == nullptr) { + auto output_span = gsl::make_span(y_data, SafeInt(M) * N); + std::fill(output_span.begin(), output_span.end(), T{}); + } + return; + } + math::Gemm(trans_a, trans_b, M, N, K, alpha, @@ -179,16 +187,18 @@ void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans if (M == 0 || N == 0) return; -#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wclass-memaccess" -#endif - // MLFloat16's constructor is explicit, so here we need to use memset + if (K == 0) { + if (beta != onnxruntime::MLFloat16::Zero && c_data != nullptr) { + GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data); + } else { + auto output_span = gsl::make_span(y_data, SafeInt(M) * N); + std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero); + } + return; + } + if (c_data == nullptr) - memset(&beta, 0, sizeof(MLFloat16)); -#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS) -#pragma GCC diagnostic pop -#endif + beta = onnxruntime::MLFloat16::Zero; #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED bool support_mlas = false; if (c_shape == nullptr) { @@ -413,19 +423,24 @@ Status Gemm::Compute(OpKernelContext* context) const { c_data, c_shape, y_data, thread_pool); } else { GemmBroadcastBias(M, N, beta_, c_data, c_shape, y_data); - MlasGemm( - trans_A_, - static_cast(M), - static_cast(N), - static_cast(K), - alpha_, - A->Data(), - static_cast(trans_A_ != CblasNoTrans ? M : K), - packed_b_.get(), - c_data != nullptr ? beta_ : 0.0f, - y_data, - static_cast(N), - thread_pool); + if (K > 0) { + MlasGemm( + trans_A_, + static_cast(M), + static_cast(N), + static_cast(K), + alpha_, + A->Data(), + static_cast(trans_A_ != CblasNoTrans ? M : K), + packed_b_.get(), + c_data != nullptr ?
beta_ : 0.0f, + y_data, + static_cast(N), + thread_pool); + } else { + auto output_span = Y->MutableDataAsSpan(); + std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero); + } } ComputeActivation(y_data, SafeInt(M) * N, thread_pool); diff --git a/onnxruntime/core/providers/cpu/math/gemm_helper.h b/onnxruntime/core/providers/cpu/math/gemm_helper.h index f37b00ac2c16..b55bf2b5dbbf 100644 --- a/onnxruntime/core/providers/cpu/math/gemm_helper.h +++ b/onnxruntime/core/providers/cpu/math/gemm_helper.h @@ -56,7 +56,8 @@ class GemmHelper { status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast"); // it is possible the input is empty tensor, for example the output of roipool in fast rcnn. - ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0); + // it is also possible that K == 0 + ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0); } ptrdiff_t M() const { return M_; } diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 4e61e0c8c69c..7fa5e74b5424 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -137,6 +137,16 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const } } + if (K == 0) { + if (beta_ == 0 || B == nullptr) { + // When we have (M, 0, N) then the output should be filled out with zeros + // unless we have a bias + Fill(Stream(ctx), reinterpret_cast(Y->MutableData()), CudaT(0.f), + Y->Shape().Size()); + } + return Status::OK(); + } + CudaT alpha = ToCudaType::FromFloat(alpha_); CudaT beta = ToCudaType::FromFloat(beta_); // Gemm, note that CUDA assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 7ec84d87b2a8..9a3313e826cf 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -641,6 +641,47 @@ TYPED_TEST(GemmOpTypedTests, GemmEmptyTensor) { .Config(run_with_tunable_op) .RunWithConfig(); } + +TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) { + OpTester test("Gemm", 13); + + test.AddAttribute("transA", static_cast(0)); + test.AddAttribute("transB", static_cast(0)); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + test.AddInput("A", {4, 0}, {}); + test.AddInput("B", {0, 4}, {}); + test.AddInput("C", {4}, std::vector(4, static_cast(1.0f))); + test.AddOutput("Y", {4, 4}, std::vector(16, static_cast(1.0f))); + + test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider, + kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, + kOpenVINOExecutionProvider}) + .Config(run_with_tunable_op) + .RunWithConfig(); +} + +TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) { + OpTester test("Gemm", 13); + + test.AddAttribute("transA", static_cast(0)); + test.AddAttribute("transB", static_cast(0)); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", .0f); + + test.AddInput("A", {4, 0}, {}); + test.AddInput("B", {0, 4}, {}); + test.AddOutput("Y", {4, 4}, std::vector(16, static_cast(0.0f))); + + test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider, + kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, + kOpenVINOExecutionProvider}) + .Config(run_with_tunable_op) + .RunWithConfig(); +} + + TYPED_TEST(GemmOpTypedTests, MissingBias) { OpTester test("Gemm", 11); From 4713219f71d74ee28981a6ff7a60c524bb28c513 Mon Sep 17 
00:00:00 2001 From: Dmitri Smirnov Date: Wed, 18 Sep 2024 14:22:19 -0700 Subject: [PATCH 02/11] Add QGemm K == 0 handling --- .../cpu/quantization/quant_gemm.cc | 94 ++++++++++++++++++- .../test/providers/cpu/math/gemm_test.cc | 1 - 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index ff8ad090820b..6731d595f40f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -12,6 +12,89 @@ namespace onnxruntime { namespace contrib { +template +void GemmBroadcastBiasScaleBackWithCast(int64_t M, int64_t N, const S* c_data, const TensorShape& bias_shape, + T* output, float a_scale, float b_scale) { + auto output_mat = EigenMatrixMapRowMajor(output, M, N); + if (bias_shape.Size() == 1) { + // C is (), (1,) or (1, 1), set the scalar + const auto constant = static_cast(static_cast(c_data[0]) * a_scale * b_scale); + output_mat.setConstant(constant); + } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { + // C is (N,) or (1, N) + output_mat.rowwise() = (ConstEigenVectorMap(c_data, N).transpose().cast() * + a_scale * b_scale) + .cast(); + } else if (bias_shape[1] == 1) { + // C is (M, 1) + output_mat.colwise() = (ConstEigenVectorMap(c_data, M).cast() * + a_scale * b_scale) + .cast(); + } else { + // C is (M, N), no broadcast needed. + output_mat = (ConstEigenMatrixMapRowMajor(c_data, M, N).cast() * + a_scale * b_scale) + .cast(); + } +} + +/// +/// This function will attempt to handle the case where K is zero while M and N are not +/// We need to fill the output either with zeros or with c_data if present. +/// +/// +/// +/// +/// +/// +/// +/// +static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor& y, const AllocatorPtr& allocator, + const Tensor* y_scale, const Tensor* y_zp, const Tensor* bias) { + const auto output_dims = y.Shape().GetDims(); + const int64_t M = output_dims[0]; + const int64_t N = output_dims[1]; + const float a_scale_value = a_scale.Data()[0]; + const float b_scale_value = b_scale.Data()[0]; + + if (y_zp == nullptr) { + // Either fill with c_data if present or 0 + int8_t* output = reinterpret_cast(y.MutableDataRaw()); + if (bias != nullptr) { + GemmBroadcastBiasScaleBackWithCast(M, N, bias->Data(), bias->Shape(), output, + 1.f, 1.f); + } else { + EigenMatrixMapRowMajor output_mat(output, M, N); + output_mat.setZero(); + } + } else { + if (bias != nullptr) { + // scale c_data back to float with result = c_data * a_scale * b_scale. + Tensor scaled_back(DataTypeImpl::GetType(), y.Shape(), allocator); + GemmBroadcastBiasScaleBackWithCast(M, N, bias->Data(), bias->Shape(), + scaled_back.MutableData(), + a_scale_value, b_scale_value); + + // re-quantize + if (y_zp->IsDataType()) { + auto q_params = quantization::GetTensorQuantizationParams(y_scale, y_zp); + quantization::Quantize(scaled_back.Data(), + reinterpret_cast(y.MutableDataRaw()), q_params, + scaled_back.Shape().Size()); + } else { + auto q_params = quantization::GetTensorQuantizationParams(y_scale, y_zp); + quantization::Quantize(scaled_back.Data(), + reinterpret_cast(y.MutableDataRaw()), q_params, + scaled_back.Shape().Size()); + } + } else { + // Fill with y_zp + int32_t y_zp_value = y_zp->IsDataType() ? 
*(y_zp->Data()) : *(y_zp->Data()); + memset(y.MutableDataRaw(), y_zp_value, M * N); + } + } +} + class QGemm : protected GemmBase, public MatMulIntegerBase { public: QGemm(const OpKernelInfo& info) : GemmBase(info), MatMulIntegerBase(info) { @@ -45,6 +128,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto y = context->Output(OUT_Y, {M, N}); + if (M == 0 || N == 0) return Status::OK(); + + if (K == 0) { + HandleZeroKCase(*a_scale, *b_scale, *y, allocator, y_scale, y_zp, c); + return Status::OK(); + } + bool a_is_signed = a->IsDataType(); const uint8_t* a_data = static_cast(a->DataRaw()); @@ -67,9 +158,6 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { } } - auto y = context->Output(OUT_Y, {M, N}); - if (M == 0 || N == 0) return Status::OK(); - // prepare output buffer of GEMM int32_t* gemm_output_data = nullptr; std::optional gemm_output_buffer; diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 9a3313e826cf..625ff29d4ccf 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -681,7 +681,6 @@ TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) { .RunWithConfig(); } - TYPED_TEST(GemmOpTypedTests, MissingBias) { OpTester test("Gemm", 11); From aba1d52e79f612353aa0a0773ae135a6299099c3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 18 Sep 2024 15:07:47 -0700 Subject: [PATCH 03/11] Make GCC happy with .template<>() --- .../contrib_ops/cpu/quantization/quant_gemm.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 6731d595f40f..dfffd9d939d0 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -22,19 +22,19 @@ void GemmBroadcastBiasScaleBackWithCast(int64_t M, int64_t N, const S* c_data, c output_mat.setConstant(constant); } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { // C is (N,) or (1, N) - output_mat.rowwise() = (ConstEigenVectorMap(c_data, N).transpose().cast() * + output_mat.rowwise() = (ConstEigenVectorMap(c_data, N).transpose().template cast() * a_scale * b_scale) - .cast(); + .template cast(); } else if (bias_shape[1] == 1) { // C is (M, 1) - output_mat.colwise() = (ConstEigenVectorMap(c_data, M).cast() * + output_mat.colwise() = (ConstEigenVectorMap(c_data, M).template cast() * a_scale * b_scale) - .cast(); + .template cast(); } else { // C is (M, N), no broadcast needed. 
- output_mat = (ConstEigenMatrixMapRowMajor(c_data, M, N).cast() * + output_mat = (ConstEigenMatrixMapRowMajor(c_data, M, N).template cast() * a_scale * b_scale) - .cast(); + .template cast(); } } From 5c336c777f53c495e605338466d5a25c5c311308 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 18 Sep 2024 16:56:20 -0700 Subject: [PATCH 04/11] Address data conversion issues --- .../cpu/quantization/quant_gemm.cc | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index dfffd9d939d0..aee31c1fc9cd 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -13,7 +13,7 @@ namespace onnxruntime { namespace contrib { template -void GemmBroadcastBiasScaleBackWithCast(int64_t M, int64_t N, const S* c_data, const TensorShape& bias_shape, +void GemmBroadcastBiasScaleBackWithCast(Eigen::Index M, Eigen::Index N, const S* c_data, const TensorShape& bias_shape, T* output, float a_scale, float b_scale) { auto output_mat = EigenMatrixMapRowMajor(output, M, N); if (bias_shape.Size() == 1) { @@ -52,8 +52,8 @@ void GemmBroadcastBiasScaleBackWithCast(int64_t M, int64_t N, const S* c_data, c static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor& y, const AllocatorPtr& allocator, const Tensor* y_scale, const Tensor* y_zp, const Tensor* bias) { const auto output_dims = y.Shape().GetDims(); - const int64_t M = output_dims[0]; - const int64_t N = output_dims[1]; + const auto M = narrow(output_dims[0]); + const auto N = narrow(output_dims[1]); const float a_scale_value = a_scale.Data()[0]; const float b_scale_value = b_scale.Data()[0]; @@ -80,17 +80,24 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor auto q_params = quantization::GetTensorQuantizationParams(y_scale, y_zp); quantization::Quantize(scaled_back.Data(), reinterpret_cast(y.MutableDataRaw()), q_params, - scaled_back.Shape().Size()); + narrow(scaled_back.Shape().Size())); } else { auto q_params = quantization::GetTensorQuantizationParams(y_scale, y_zp); quantization::Quantize(scaled_back.Data(), reinterpret_cast(y.MutableDataRaw()), q_params, - scaled_back.Shape().Size()); + narrow(scaled_back.Shape().Size())); } } else { // Fill with y_zp - int32_t y_zp_value = y_zp->IsDataType() ? 
*(y_zp->Data()) : *(y_zp->Data()); - memset(y.MutableDataRaw(), y_zp_value, M * N); + if (y_zp->IsDataType()) { + int8_t* output = reinterpret_cast(y.MutableDataRaw()); + EigenMatrixMapRowMajor output_mat(output, M, N); + output_mat.setConstant(*(y_zp->Data())); + } else { + uint8_t* output = reinterpret_cast(y.MutableDataRaw()); + EigenMatrixMapRowMajor output_mat(output, M, N); + output_mat.setConstant(*(y_zp->Data())); + } } } } From 837fcf732ef2f5ef7d2cecfa42ae1749d9fa0799 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 18 Sep 2024 17:29:43 -0700 Subject: [PATCH 05/11] Implement clipping properly --- .../cpu/quantization/quant_gemm.cc | 69 ++++++++++++------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index aee31c1fc9cd..0d36d5301e94 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -12,35 +12,55 @@ namespace onnxruntime { namespace contrib { -void GemmBroadcastBiasScaleBackWithCast(Eigen::Index M, Eigen::Index N, const S* c_data, const TensorShape& bias_shape, - T* output, float a_scale, float b_scale) { - auto output_mat = EigenMatrixMapRowMajor(output, M, N); +void GemmBroadcastBiasScaleBack(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, + const TensorShape& bias_shape, + float* output, float a_scale, float b_scale) { + auto output_mat = EigenMatrixMapRowMajor(output, M, N); if (bias_shape.Size() == 1) { // C is (), (1,) or (1, 1), set the scalar - const auto constant = static_cast(static_cast(c_data[0]) * a_scale * b_scale); + const auto constant = static_cast(bias_data[0]) * a_scale * b_scale; output_mat.setConstant(constant); } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { // C is (N,) or (1, N) - output_mat.rowwise() = (ConstEigenVectorMap(c_data, N).transpose().template cast() * - a_scale * b_scale) - .template cast(); + output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().cast() * + a_scale * b_scale; } else if (bias_shape[1] == 1) { // C is (M, 1) - output_mat.colwise() = (ConstEigenVectorMap(c_data, M).template cast() * - a_scale * b_scale) - .template cast(); + output_mat.colwise() = ConstEigenVectorMap(bias_data, M).cast() * + a_scale * b_scale; } else { // C is (M, N), no broadcast needed. - output_mat = (ConstEigenMatrixMapRowMajor(c_data, M, N).template cast() * - a_scale * b_scale) - .template cast(); + output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).cast() * + a_scale * b_scale; + } } + +void GemmBroadcastBiasAndClip(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, + const TensorShape& bias_shape, uint8_t* output) { + auto clip = [](int32_t v) -> uint8_t { + return static_cast(v & 0xff); + }; + + auto output_mat = EigenMatrixMapRowMajor(output, M, N); + if (bias_shape.Size() == 1) { + // C is (), (1,) or (1, 1), set the scalar + const auto constant = clip(bias_data[0]); + output_mat.setConstant(constant); + } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { + // C is (N,) or (1, N) + output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().unaryExpr(clip); + } else if (bias_shape[1] == 1) { + // C is (M, 1) + output_mat.colwise() = ConstEigenVectorMap(bias_data, M).unaryExpr(clip); + } else { + // C is (M, N), no broadcast needed.
+ output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).unaryExpr(clip); } } /// /// This function will attempt to handle the case where K is zero while M and N are not -/// We need to fill the output either with zeros or with c_data if present. +/// We need to fill the output either with zeros or with bias_data if present. /// /// /// @@ -48,7 +68,7 @@ void GemmBroadcastBiasScaleBackWithCast(Eigen::Index M, Eigen::Index N, const S* /// /// /// -/// +/// static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor& y, const AllocatorPtr& allocator, const Tensor* y_scale, const Tensor* y_zp, const Tensor* bias) { const auto output_dims = y.Shape().GetDims(); @@ -58,22 +78,21 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor const float b_scale_value = b_scale.Data()[0]; if (y_zp == nullptr) { - // Either fill with c_data if present or 0 - int8_t* output = reinterpret_cast(y.MutableDataRaw()); + // Either fill with bias_data if present or 0 + uint8_t* output = reinterpret_cast(y.MutableDataRaw()); if (bias != nullptr) { - GemmBroadcastBiasScaleBackWithCast(M, N, bias->Data(), bias->Shape(), output, - 1.f, 1.f); + GemmBroadcastBiasAndClip(M, N, bias->Data(), bias->Shape(), output); } else { - EigenMatrixMapRowMajor output_mat(output, M, N); + EigenMatrixMapRowMajor output_mat(output, M, N); output_mat.setZero(); } } else { if (bias != nullptr) { - // scale c_data back to float with result = c_data * a_scale * b_scale. + // scale bias_data back to float with result = bias_data * a_scale * b_scale. Tensor scaled_back(DataTypeImpl::GetType(), y.Shape(), allocator); - GemmBroadcastBiasScaleBackWithCast(M, N, bias->Data(), bias->Shape(), - scaled_back.MutableData(), - a_scale_value, b_scale_value); + GemmBroadcastBiasScaleBack(M, N, bias->Data(), bias->Shape(), + scaled_back.MutableData(), + a_scale_value, b_scale_value); // re-quantize if (y_zp->IsDataType()) { From 0353d5f6d60408d2f805db537eb0d263515795e5 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 19 Sep 2024 10:47:57 -0700 Subject: [PATCH 06/11] Rework BiasBroadcast --- .../cpu/quantization/quant_gemm.cc | 75 +++++++------------ 1 file changed, 25 insertions(+), 50 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 0d36d5301e94..c6dc465c56b5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -12,49 +12,23 @@ namespace onnxruntime { namespace contrib { -void GemmBroadcastBiasScaleBack(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, - const TensorShape& bias_shape, - float* output, float a_scale, float b_scale) { - auto output_mat = EigenMatrixMapRowMajor(output, M, N); +template +void GemmBroadcastBiasAndApplyFn(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, + const TensorShape& bias_shape, T* output, ApplyFn apply_fn) { + auto output_mat = EigenMatrixMapRowMajor(output, M, N); if (bias_shape.Size() == 1) { // C is (), (1,) or (1, 1), set the scalar - const auto constant = static_cast(bias_data[0]) * a_scale * b_scale; + const auto constant = apply_fn(bias_data[0]); output_mat.setConstant(constant); } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { // C is (N,) or (1, N) - output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().cast() * - a_scale * b_scale; + output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().unaryExpr(apply_fn); } 
else if (bias_shape[1] == 1) { // C is (M, 1) - output_mat.colwise() = ConstEigenVectorMap(bias_data, M).cast() * - a_scale * b_scale; + output_mat.colwise() = ConstEigenVectorMap(bias_data, M).unaryExpr(apply_fn); } else { // C is (M, N), no broadcast needed. - output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).cast() * - a_scale * b_scale; - } -} - -void GemmBroadcastBiasAndClip(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, - const TensorShape& bias_shape, uint8_t* output) { - auto clip = [](int32_t v) -> uint8_t { - return static_cast(v & 0xff); - }; - - auto output_mat = EigenMatrixMapRowMajor(output, M, N); - if (bias_shape.Size() == 1) { - // C is (), (1,) or (1, 1), set the scalar - const auto constant = clip(bias_data[0]); - output_mat.setConstant(constant); - } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { - // C is (N,) or (1, N) - output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().unaryExpr(clip); - } else if (bias_shape[1] == 1) { - // C is (M, 1) - output_mat.colwise() = ConstEigenVectorMap(bias_data, M).unaryExpr(clip); - } else { - // C is (M, N), no broadcast needed. - output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).unaryExpr(clip); + output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).unaryExpr(apply_fn); } } @@ -74,25 +48,32 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor const auto output_dims = y.Shape().GetDims(); const auto M = narrow(output_dims[0]); const auto N = narrow(output_dims[1]); - const float a_scale_value = a_scale.Data()[0]; - const float b_scale_value = b_scale.Data()[0]; if (y_zp == nullptr) { // Either fill with bias_data if present or 0 uint8_t* output = reinterpret_cast(y.MutableDataRaw()); if (bias != nullptr) { - GemmBroadcastBiasAndClip(M, N, bias->Data(), bias->Shape(), output); + auto clip_fn = [](uint32_t v) -> uint8_t { return static_cast(v & 0xFF); }; + GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), + bias->Shape(), output, clip_fn); } else { EigenMatrixMapRowMajor output_mat(output, M, N); output_mat.setZero(); } } else { if (bias != nullptr) { + const float a_scale_value = a_scale.Data()[0]; + const float b_scale_value = b_scale.Data()[0]; + // scale bias_data back to float with result = bias_data * a_scale * b_scale. 
Tensor scaled_back(DataTypeImpl::GetType(), y.Shape(), allocator); + + auto scale_back_fn = [a_scale_value, b_scale_value](int32_t v) -> float { + return static_cast(v) * a_scale_value * b_scale_value; + }; + + GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), bias->Shape(), + scaled_back.MutableData(), scale_back_fn); // re-quantize if (y_zp->IsDataType()) { @@ -107,16 +88,10 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor narrow(scaled_back.Shape().Size())); } } else { - // Fill with y_zp - if (y_zp->IsDataType()) { - int8_t* output = reinterpret_cast(y.MutableDataRaw()); - EigenMatrixMapRowMajor output_mat(output, M, N); - output_mat.setConstant(*(y_zp->Data())); - } else { - uint8_t* output = reinterpret_cast(y.MutableDataRaw()); - EigenMatrixMapRowMajor output_mat(output, M, N); - output_mat.setConstant(*(y_zp->Data())); - } + // We just fill out the output, does not matter signed or unsigned + int8_t* output = reinterpret_cast(y.MutableDataRaw()); + EigenMatrixMapRowMajor output_mat(output, M, N); + output_mat.setConstant(*(y_zp->Data())); } } } From c68f2da5b6d49d4f23aa7a568808edd300a49fea Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 19 Sep 2024 10:55:08 -0700 Subject: [PATCH 07/11] Fix zp fill out --- onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index c6dc465c56b5..1880c83780f2 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -88,10 +88,10 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor narrow(scaled_back.Shape().Size())); } } else { + const int32_t zp = (y_zp->IsDataType()) ?
*(y_zp->Data()) : *(y_zp->Data()); + // We just fill out the output, does not matter signed or unsigned int8_t* output = reinterpret_cast(y.MutableDataRaw()); - EigenMatrixMapRowMajor output_mat(output, M, N); - output_mat.setConstant(*(y_zp->Data())); + memset(output, zp, narrow(M * N)); } } } From 960179368e2bfd7f33168ebe727710180563ba8c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 19 Sep 2024 11:01:08 -0700 Subject: [PATCH 08/11] Rework zp fill out --- .../contrib_ops/cpu/quantization/quant_gemm.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 1880c83780f2..a5e8ad516f48 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -88,10 +88,13 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor narrow(scaled_back.Shape().Size())); } } else { - const int32_t zp = (y_zp->IsDataType()) ? *(y_zp->Data()) : *(y_zp->Data()); - // We just fill out the output, does not matter signed or unsigned - int8_t* output = reinterpret_cast(y.MutableDataRaw()); - memset(output, zp, narrow(M * N)); + if (y_zp->IsDataType()) { + EigenMatrixMapRowMajor output_mat(reinterpret_cast(y.MutableDataRaw()), M, N); + output_mat.setConstant(*(y_zp->Data())); + } else { + EigenMatrixMapRowMajor output_mat(reinterpret_cast(y.MutableDataRaw()), M, N); + output_mat.setConstant(*(y_zp->Data())); + } } } } From 9ca55b4135136b62f0eb6f32e53cf5349efec5fe Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 19 Sep 2024 12:21:40 -0700 Subject: [PATCH 09/11] Make sure output is float with y_zp not present --- onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index a5e8ad516f48..415589b1edc9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -50,14 +50,16 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor const auto N = narrow(output_dims[1]); if (y_zp == nullptr) { + // Because y_zp is not provided, the output is float32 // Either fill with bias_data if present or 0 - uint8_t* output = reinterpret_cast(y.MutableDataRaw()); + ORT_ENFORCE(y.SizeInBytes() == SafeInt(M) * N * sizeof(float), "Output must be sized for float"); + float* output = reinterpret_cast(y.MutableDataRaw()); if (bias != nullptr) { - auto clip_fn = [](uint32_t v) -> uint8_t { return static_cast(v & 0xFF); }; + auto to_float = [](uint32_t v) -> float { return static_cast(v); }; GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), - bias->Shape(), output, clip_fn); + bias->Shape(), output, to_float); } else { - EigenMatrixMapRowMajor output_mat(output, M, N); + EigenMatrixMapRowMajor output_mat(output, M, N); output_mat.setZero(); } } else { From 09f42ecce6f2a0d436def815f6a22658db9821ad Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 19 Sep 2024 14:02:38 -0700 Subject: [PATCH 10/11] Account for b_scale being a vector of N --- .../cpu/quantization/quant_gemm.cc | 37 +++++++++++++------ .../test/contrib_ops/quant_gemm_test.cc | 15 ++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 415589b1edc9..28a08bd5232e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -52,7 +52,8 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor if (y_zp == nullptr) { // Because y_zp is not provided, the output is float32 // Either fill with bias_data if present or 0 - ORT_ENFORCE(y.SizeInBytes() == SafeInt(M) * N * sizeof(float), "Output must be sized for float"); + ORT_ENFORCE(y.SizeInBytes() == SafeInt(M) * N * sizeof(float), + "Output must be sized for float"); float* output = reinterpret_cast(y.MutableDataRaw()); if (bias != nullptr) { auto to_float = [](uint32_t v) -> float { return static_cast(v); }; @@ -64,18 +65,32 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor } } else { if (bias != nullptr) { - const float a_scale_value = a_scale.Data()[0]; - const float b_scale_value = b_scale.Data()[0]; - // scale bias_data back to float with result = bias_data * a_scale * b_scale.
Tensor scaled_back(DataTypeImpl::GetType(), y.Shape(), allocator); + const float a_scale_value = a_scale.Data()[0]; + const auto& b_shape = b_scale.Shape(); + if (b_shape.Size() == 1) { + // b_scale is a scalar + const float b_scale_value = b_scale.Data()[0]; + auto scale_back_fn = [a_scale_value, b_scale_value](int32_t v) -> float { + return static_cast(v) * a_scale_value * b_scale_value; + }; + GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), bias->Shape(), + scaled_back.MutableData(), scale_back_fn); + } else { + // b_scale is a 1-D tensor which should be the size of N + ORT_ENFORCE(b_shape[0] == N, "Length of b_scale is expected to be equal to N"); + const auto* b_scaled_data = b_scale.Data(); + Eigen::Index counter = 0; + auto scale_back_fn = [a_scale_value, b_scaled_data, N, counter](int32_t v) mutable -> float { + auto b_idx = counter++ % N; + return static_cast(v) * a_scale_value * b_scaled_data[b_idx]; + }; + + std::function fn{scale_back_fn}; + GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), bias->Shape(), + scaled_back.MutableData(), fn); + } // re-quantize if (y_zp->IsDataType()) { diff --git a/onnxruntime/test/contrib_ops/quant_gemm_test.cc b/onnxruntime/test/contrib_ops/quant_gemm_test.cc index 3afcd6651aad..5549cd08f752 100644 --- a/onnxruntime/test/contrib_ops/quant_gemm_test.cc +++ b/onnxruntime/test/contrib_ops/quant_gemm_test.cc @@ -202,5 +202,20 @@ TEST(QuantGemmTest, GEMM) { RunQuantGemmTestBatch(4, 8, 68); } +TEST(QuantGemmTest, EmptyInputsNoBiasNoZp) { + OpTester test("QGemm", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", 0); + test.AddAttribute("transB", 0); + test.AddAttribute("alpha", 1.f); + + test.AddInput("A", {4, 0}, {}); + test.AddInput("a_scale", {}, {1.f}); + test.AddInput("a_zero_point", {}, {-1}); + + test.AddInput("B", {0, 4}, {}); + test.AddInput("b_scale", {}, {1.f}); + test.AddInput("b_zero_point", {}, {-1}); +} + } // namespace test } // namespace onnxruntime From af0d25e21c8349c3ad3912dc441b7ded314017c7 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 19 Sep 2024 16:39:39 -0700 Subject: [PATCH 11/11] Remove QGemm changes --- .../cpu/quantization/quant_gemm.cc | 115 +----------------- .../test/contrib_ops/quant_gemm_test.cc | 15 --- 2 files changed, 3 insertions(+), 127 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 28a08bd5232e..ff8ad090820b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -12,110 +12,6 @@ namespace onnxruntime { namespace contrib { -template -void GemmBroadcastBiasAndApplyFn(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, - const TensorShape& bias_shape, T* output, ApplyFn apply_fn) { - auto output_mat = EigenMatrixMapRowMajor(output, M, N); - if (bias_shape.Size() == 1) { - // C is (), (1,) or (1, 1), set the scalar - const auto constant = apply_fn(bias_data[0]); - output_mat.setConstant(constant); - } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { - // C is (N,) or (1, N) - output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().unaryExpr(apply_fn); - } else if (bias_shape[1] == 1) { - // C is (M, 1) - output_mat.colwise() =
ConstEigenVectorMap(bias_data, M).unaryExpr(apply_fn); - } else { - // C is (M, N), no broadcast needed. - output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).unaryExpr(apply_fn); - } -} - -/// -/// This function will attempt to handle the case where K is zero while M and N are not -/// We need to fill the output either with zeros or with bias_data if present. -/// -/// -/// -/// -/// -/// -/// -/// -static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor& y, const AllocatorPtr& allocator, - const Tensor* y_scale, const Tensor* y_zp, const Tensor* bias) { - const auto output_dims = y.Shape().GetDims(); - const auto M = narrow(output_dims[0]); - const auto N = narrow(output_dims[1]); - - if (y_zp == nullptr) { - // Because y_zp is not provided, the output is float32 - // Either fill with bias_data if present or 0 - ORT_ENFORCE(y.SizeInBytes() == SafeInt(M) * N * sizeof(float), - "Output must be sized for float"); - float* output = reinterpret_cast(y.MutableDataRaw()); - if (bias != nullptr) { - auto to_float = [](uint32_t v) -> float { return static_cast(v); }; - GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), - bias->Shape(), output, to_float); - } else { - EigenMatrixMapRowMajor output_mat(output, M, N); - output_mat.setZero(); - } - } else { - if (bias != nullptr) { - // scale bias_data back to float with result = bias_data * a_scale * b_scale. - Tensor scaled_back(DataTypeImpl::GetType(), y.Shape(), allocator); - const float a_scale_value = a_scale.Data()[0]; - const auto& b_shape = b_scale.Shape(); - if (b_shape.Size() == 1) { - // b_scale is a scalar - const float b_scale_value = b_scale.Data()[0]; - auto scale_back_fn = [a_scale_value, b_scale_value](int32_t v) -> float { - return static_cast(v) * a_scale_value * b_scale_value; - }; - GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), bias->Shape(), - scaled_back.MutableData(), scale_back_fn); - } else { - // b_scale is a 1-D tensor which should be the size of N - ORT_ENFORCE(b_shape[0] == N, "Length of b_scale is expected to be equal to N"); - const auto* b_scaled_data = b_scale.Data(); - Eigen::Index counter = 0; - auto scale_back_fn = [a_scale_value, b_scaled_data, N, counter](int32_t v) mutable -> float { - auto b_idx = counter++ % N; - return static_cast(v) * a_scale_value * b_scaled_data[b_idx]; - }; - - std::function fn{scale_back_fn}; - GemmBroadcastBiasAndApplyFn(M, N, bias->Data(), bias->Shape(), - scaled_back.MutableData(), fn); - } - - // re-quantize - if (y_zp->IsDataType()) { - auto q_params = quantization::GetTensorQuantizationParams(y_scale, y_zp); - quantization::Quantize(scaled_back.Data(), - reinterpret_cast(y.MutableDataRaw()), q_params, - narrow(scaled_back.Shape().Size())); - } else { - auto q_params = quantization::GetTensorQuantizationParams(y_scale, y_zp); - quantization::Quantize(scaled_back.Data(), - reinterpret_cast(y.MutableDataRaw()), q_params, - narrow(scaled_back.Shape().Size())); - } - } else { - if (y_zp->IsDataType()) { - EigenMatrixMapRowMajor output_mat(reinterpret_cast(y.MutableDataRaw()), M, N); - output_mat.setConstant(*(y_zp->Data())); - } else { - EigenMatrixMapRowMajor output_mat(reinterpret_cast(y.MutableDataRaw()), M, N); - output_mat.setConstant(*(y_zp->Data())); - } - } - } -} - class QGemm : protected GemmBase, public MatMulIntegerBase { public: QGemm(const OpKernelInfo& info) : GemmBase(info), MatMulIntegerBase(info) { @@ -45,6 +128,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - auto y = context->Output(OUT_Y, {M, N}); - if (M == 0 || N == 0) return Status::OK(); - - if (K == 0) { - HandleZeroKCase(*a_scale, *b_scale, *y, allocator, y_scale, y_zp, c); - return Status::OK(); - } - bool a_is_signed = a->IsDataType(); const uint8_t* a_data = static_cast(a->DataRaw()); @@ -179,6 +67,9 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { } } + auto y = context->Output(OUT_Y, {M, N}); + if (M == 0 || N == 0) return Status::OK(); + // prepare output buffer of GEMM int32_t* gemm_output_data = nullptr; std::optional gemm_output_buffer; diff --git a/onnxruntime/test/contrib_ops/quant_gemm_test.cc b/onnxruntime/test/contrib_ops/quant_gemm_test.cc index 5549cd08f752..3afcd6651aad 100644 --- a/onnxruntime/test/contrib_ops/quant_gemm_test.cc +++ b/onnxruntime/test/contrib_ops/quant_gemm_test.cc @@ -202,20 +202,5 @@ TEST(QuantGemmTest, GEMM) { RunQuantGemmTestBatch(4, 8, 68); } -TEST(QuantGemmTest, EmptyInputsNoBiasNoZp) { - OpTester test("QGemm", 1, onnxruntime::kMSDomain); - test.AddAttribute("transA", 0); - test.AddAttribute("transB", 0); - test.AddAttribute("alpha", 1.f); - - test.AddInput("A", {4, 0}, {}); - test.AddInput("a_scale", {}, {1.f}); - test.AddInput("a_zero_point", {}, {-1}); - - test.AddInput("B", {0, 4}, {}); - test.AddInput("b_scale", {}, {1.f}); - test.AddInput("b_zero_point", {}, {-1}); -} - } // namespace test } // namespace onnxruntime
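
Note on the zero-K semantics this series implements: with A of shape (M, 0) and B of shape (0, N), the reduction dimension is empty, so alpha * A * B contributes nothing and Gemm degenerates to Y = beta * broadcast(C), or all zeros when beta is 0 or C is absent. The sketch below is a minimal standalone illustration of that contract, not ONNX Runtime code; the function name GemmZeroK is made up for this example, and it only handles a length-N bias row vector, whereas the real kernels above also broadcast the (), (M, 1) and (M, N) shapes.

    // Minimal sketch of the K == 0 Gemm contract described above.
    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Fills the row-major M x N output y for the K == 0 case.
    // c may be nullptr (no bias); here it is assumed to be a length-N row vector.
    void GemmZeroK(std::size_t M, std::size_t N, float beta, const float* c, float* y) {
      if (beta == 0.0f || c == nullptr) {
        std::fill(y, y + M * N, 0.0f);  // no bias contribution: zero the output
        return;
      }
      for (std::size_t i = 0; i < M; ++i)
        for (std::size_t j = 0; j < N; ++j)
          y[i * N + j] = beta * c[j];  // Y = beta * C, broadcast over rows
    }

    int main() {
      std::vector<float> c(4, 1.0f), y(16);
      GemmZeroK(4, 4, 1.0f, c.data(), y.data());  // mirrors the ZeroKWithBias test
      for (float v : y) std::cout << v << ' ';    // prints sixteen 1s
      std::cout << '\n';
    }

Run against the values of the ZeroKWithBias test from patch 01 (M = N = 4, beta = 1, C all ones) it produces an all-ones 4x4 output, matching the expected Y of that test; with beta = 0 and no C it produces the all-zeros output of ZeroKWithNoBias.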