Address ZeroK case for Gemm for CPU and CUDA #22111

Open
wants to merge 10 commits into base: main
4 changes: 1 addition & 3 deletions cmake/onnxruntime_python.cmake
@@ -71,9 +71,7 @@ onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_p

if(MSVC)
target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
if(onnxruntime_ENABLE_TRAINING)
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
endif()
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
endif()
if(HAS_CAST_FUNCTION_TYPE)
target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type")
115 changes: 112 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
@@ -12,6 +12,110 @@
namespace onnxruntime {
namespace contrib {

template <class T, class ApplyFn>
void GemmBroadcastBiasAndApplyFn(Eigen::Index M, Eigen::Index N, const int32_t* bias_data,
const TensorShape& bias_shape, T* output, ApplyFn apply_fn) {
auto output_mat = EigenMatrixMapRowMajor<T>(output, M, N);
if (bias_shape.Size() == 1) {
// C is (), (1,) or (1, 1), set the scalar
const auto constant = apply_fn(bias_data[0]);
output_mat.setConstant(constant);
} else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) {
// C is (N,) or (1, N)
output_mat.rowwise() = ConstEigenVectorMap<int32_t>(bias_data, N).transpose().unaryExpr(apply_fn);
} else if (bias_shape[1] == 1) {
// C is (M, 1)
output_mat.colwise() = ConstEigenVectorMap<int32_t>(bias_data, M).unaryExpr(apply_fn);
} else {
// C is (M, N), no broadcast needed.
output_mat = ConstEigenMatrixMapRowMajor<int32_t>(bias_data, M, N).unaryExpr(apply_fn);
}
}
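
Editorial note: the helper above dispatches on the ONNX Gemm bias shapes (scalar, (N,)/(1, N), (M, 1), (M, N)) and broadcasts C onto the (M, N) output while applying a per-element functor. Below is a minimal, self-contained Eigen sketch of the same broadcast dispatch; the names RowMajorMatrixXf and broadcast_bias are illustrative and not part of this diff.

// Illustrative sketch only: mirrors the bias-broadcast dispatch above using plain Eigen.
#include <Eigen/Dense>
#include <cassert>
#include <vector>

using RowMajorMatrixXf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Broadcast a bias of shape rows x cols (rows/cols may be 1) onto an M x N output.
RowMajorMatrixXf broadcast_bias(const std::vector<float>& bias,
                                Eigen::Index rows, Eigen::Index cols,
                                Eigen::Index M, Eigen::Index N) {
  RowMajorMatrixXf out(M, N);
  if (rows * cols == 1) {
    out.setConstant(bias[0]);                                                        // C is (), (1,) or (1, 1)
  } else if (rows == 1) {
    out.rowwise() = Eigen::Map<const Eigen::VectorXf>(bias.data(), N).transpose();   // C is (N,) or (1, N)
  } else if (cols == 1) {
    out.colwise() = Eigen::Map<const Eigen::VectorXf>(bias.data(), M);               // C is (M, 1)
  } else {
    out = Eigen::Map<const RowMajorMatrixXf>(bias.data(), M, N);                     // C is (M, N)
  }
  return out;
}

int main() {
  std::vector<float> bias = {1.f, 2.f, 3.f};
  RowMajorMatrixXf y = broadcast_bias(bias, /*rows=*/1, /*cols=*/3, /*M=*/2, /*N=*/3);
  assert(y(0, 0) == 1.f && y(1, 2) == 3.f);
  return 0;
}
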

/// <summary>
/// Handles the case where K is zero while M and N are not.
/// The output is filled either with zeros or with the (optionally re-quantized) bias, if one is present.
/// </summary>
/// <param name="a_scale">Quantization scale of A (scalar)</param>
/// <param name="b_scale">Quantization scale of B (scalar or 1-D of length N)</param>
/// <param name="y">Output tensor of shape (M, N)</param>
/// <param name="allocator">Allocator used for the intermediate float buffer</param>
/// <param name="y_scale">Optional output quantization scale</param>
/// <param name="y_zp">Optional output zero point; its element type selects int8/uint8 output</param>
/// <param name="bias">Optional int32 bias tensor C</param>
static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor& y, const AllocatorPtr& allocator,
const Tensor* y_scale, const Tensor* y_zp, const Tensor* bias) {
const auto output_dims = y.Shape().GetDims();
const auto M = narrow<Eigen::Index>(output_dims[0]);
const auto N = narrow<Eigen::Index>(output_dims[1]);

if (y_zp == nullptr) {
// Because y_zp is not provided, the output is float32
// Either fill with bias_data if present or 0
ORT_ENFORCE(y.SizeInBytes() == SafeInt<size_t>(M) * N * sizeof(float),
"Output must be sized for float");
float* output = reinterpret_cast<float*>(y.MutableDataRaw());
if (bias != nullptr) {
auto to_float = [](int32_t v) -> float { return static_cast<float>(v); };
GemmBroadcastBiasAndApplyFn(M, N, bias->Data<int32_t>(),
bias->Shape(), output, to_float);
} else {
EigenMatrixMapRowMajor<float> output_mat(output, M, N);
output_mat.setZero();
}
} else {
if (bias != nullptr) {
// scale bias_data back to float with result = bias_data * a_scale * b_scale.
Tensor scaled_back(DataTypeImpl::GetType<float>(), y.Shape(), allocator);
const float a_scale_value = a_scale.Data<float>()[0];
const auto& b_shape = b_scale.Shape();
if (b_shape.Size() == 1) {
// b_scale is a scalar
const float b_scale_value = b_scale.Data<float>()[0];
auto scale_back_fn = [a_scale_value, b_scale_value](int32_t v) -> float {
return static_cast<float>(v) * a_scale_value * b_scale_value;
};
GemmBroadcastBiasAndApplyFn(M, N, bias->Data<int32_t>(), bias->Shape(),
scaled_back.MutableData<float>(), scale_back_fn);
} else {
// b_scale is a 1-D tensor whose length is expected to equal N
ORT_ENFORCE(b_shape[0] == N, "Length of b_scale is expected to be equal to N");
const auto* b_scaled_data = b_scale.Data<float>();
Eigen::Index counter = 0;
auto scale_back_fn = [a_scale_value, b_scaled_data, N, counter](int32_t v) mutable -> float {
auto b_idx = counter++ % N;
return static_cast<float>(v) * a_scale_value * b_scaled_data[b_idx];
};

std::function<float(int32_t)> fn{scale_back_fn};
GemmBroadcastBiasAndApplyFn(M, N, bias->Data<int32_t>(), bias->Shape(),
scaled_back.MutableData<float>(), fn);
}

// re-quantize
if (y_zp->IsDataType<int8_t>()) {
auto q_params = quantization::GetTensorQuantizationParams<int8_t>(y_scale, y_zp);
quantization::Quantize<int8_t>(scaled_back.Data<float>(),
reinterpret_cast<int8_t*>(y.MutableDataRaw()), q_params,
narrow<size_t>(scaled_back.Shape().Size()));
} else {
auto q_params = quantization::GetTensorQuantizationParams<uint8_t>(y_scale, y_zp);
quantization::Quantize<uint8_t>(scaled_back.Data<float>(),
reinterpret_cast<uint8_t*>(y.MutableDataRaw()), q_params,
narrow<size_t>(scaled_back.Shape().Size()));
}
} else {
if (y_zp->IsDataType<int8_t>()) {
EigenMatrixMapRowMajor<int8_t> output_mat(reinterpret_cast<int8_t*>(y.MutableDataRaw()), M, N);
output_mat.setConstant(*(y_zp->Data<int8_t>()));
} else {
EigenMatrixMapRowMajor<uint8_t> output_mat(reinterpret_cast<uint8_t*>(y.MutableDataRaw()), M, N);
output_mat.setConstant(*(y_zp->Data<uint8_t>()));
}
}
}
}
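
Editorial note: the bias path above dequantizes the int32 bias (bias * a_scale * b_scale) and then requantizes it with the output scale and zero point. The standalone sketch below shows that round trip under the assumption of standard asymmetric linear quantization with saturating round-to-nearest; the helper names are illustrative and not onnxruntime APIs, and the actual quantization::Quantize helper may differ in rounding details.

// Illustrative sketch only, not part of the diff: dequantize an int32 bias value and
// requantize it to int8, mirroring the scale-back + Quantize steps above.
#include <algorithm>
#include <cmath>
#include <cstdint>

// Bias accumulators in a quantized GEMM carry an implicit scale of a_scale * b_scale.
inline float DequantizeBias(int32_t bias, float a_scale, float b_scale) {
  return static_cast<float>(bias) * a_scale * b_scale;
}

// q = clamp(round(x / y_scale) + y_zero_point) to the int8 range.
inline int8_t RequantizeToInt8(float x, float y_scale, int8_t y_zero_point) {
  const float v = std::nearbyint(x / y_scale) + static_cast<float>(y_zero_point);
  return static_cast<int8_t>(std::clamp(v, -128.f, 127.f));
}

int main() {
  const float a_scale = 0.5f, b_scale = 0.25f, y_scale = 0.1f;
  const int8_t y_zp = 3;
  // With K == 0 the matmul contributes nothing, so the output is just the requantized bias.
  const int8_t q = RequantizeToInt8(DequantizeBias(/*bias=*/8, a_scale, b_scale), y_scale, y_zp);
  return q == 13 ? 0 : 1;  // 8 * 0.5 * 0.25 = 1.0; round(1.0 / 0.1) = 10; 10 + 3 = 13
}
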

class QGemm : protected GemmBase, public MatMulIntegerBase {
public:
QGemm(const OpKernelInfo& info) : GemmBase(info), MatMulIntegerBase(info) {
@@ -45,6 +149,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

auto y = context->Output(OUT_Y, {M, N});
if (M == 0 || N == 0) return Status::OK();

if (K == 0) {
HandleZeroKCase(*a_scale, *b_scale, *y, allocator, y_scale, y_zp, c);
return Status::OK();
}

bool a_is_signed = a->IsDataType<int8_t>();
const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());

@@ -67,9 +179,6 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
}
}

auto y = context->Output(OUT_Y, {M, N});
if (M == 0 || N == 0) return Status::OK();

// prepare output buffer of GEMM
int32_t* gemm_output_data = nullptr;
std::optional<Tensor> gemm_output_buffer;
59 changes: 37 additions & 22 deletions onnxruntime/core/providers/cpu/math/gemm.cc
@@ -154,6 +154,14 @@ void Gemm<T>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
// Broadcast the bias as needed if bias is given
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);

if (K == 0) {
if (beta == 0 || c_data == nullptr) {
auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
std::fill(output_span.begin(), output_span.end(), T{});
}
return;
}
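
Editorial note: when K == 0 the A*B product is an empty sum, so ONNX Gemm degenerates to Y = beta * broadcast(C), or all zeros when beta == 0 or C is absent; that is exactly what the early return above implements. A tiny reference loop making this explicit (illustrative only; it assumes the bias pointer, if non-null, already refers to C broadcast to M x N):

// Illustrative reference only, not part of the diff: Gemm with K == 0.
#include <cstddef>

void GemmZeroKReference(std::size_t M, std::size_t N, float beta,
                        const float* bias_or_null, float* y) {
  for (std::size_t i = 0; i < M * N; ++i) {
    // The sum over k in [0, K) is empty, so alpha * (A*B) contributes nothing.
    y[i] = (bias_or_null != nullptr) ? beta * bias_or_null[i] : 0.0f;
  }
}
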

math::Gemm<T>(trans_a, trans_b,
M, N, K,
alpha,
@@ -179,16 +187,18 @@ void Gemm<MLFloat16>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans
if (M == 0 || N == 0)
return;

#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
// MLFloat16's constructor is explicit, so here we need to use memset
if (K == 0) {
if (beta != onnxruntime::MLFloat16::Zero && c_data != nullptr) {
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
} else {
auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero);
}
return;
}

if (c_data == nullptr)
memset(&beta, 0, sizeof(MLFloat16));
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
#pragma GCC diagnostic pop
#endif
beta = onnxruntime::MLFloat16::Zero;
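
Editorial note: the removed code zeroed beta through memset because MLFloat16's converting constructor is explicit, which is also what required the -Wclass-memaccess suppression; assigning the predefined MLFloat16::Zero constant expresses the same intent directly. A standalone sketch of the pattern with a stand-in half type (Half16 is illustrative, not onnxruntime::MLFloat16):

// Illustrative sketch only, not part of the diff: why an explicit-constructor half type
// pushes callers toward memset, and how a named zero constant avoids that.
#include <cstdint>
#include <cstring>

struct Half16 {                                    // stand-in for a half-precision wrapper type
  uint16_t val{0};
  Half16() = default;
  explicit Half16(uint16_t bits) : val(bits) {}    // explicit: `beta = 0;` does not compile
  static const Half16 Zero;
};
const Half16 Half16::Zero{};

int main() {
  Half16 beta(0x3c00);                        // some non-zero value (1.0 in IEEE half)
  std::memset(&beta, 0, sizeof(Half16));      // old pattern: byte-wise zeroing
  beta = Half16::Zero;                        // new pattern: assign the named zero constant
  return beta.val == 0 ? 0 : 1;
}
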
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
bool support_mlas = false;
if (c_shape == nullptr) {
@@ -413,19 +423,24 @@ Status Gemm<float>::Compute(OpKernelContext* context) const {
c_data, c_shape, y_data, thread_pool);
} else {
GemmBroadcastBias(M, N, beta_, c_data, c_shape, y_data);
MlasGemm(
trans_A_,
static_cast<size_t>(M),
static_cast<size_t>(N),
static_cast<size_t>(K),
alpha_,
A->Data<float>(),
static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
packed_b_.get(),
c_data != nullptr ? beta_ : 0.0f,
y_data,
static_cast<size_t>(N),
thread_pool);
if (K > 0) {
MlasGemm(
trans_A_,
static_cast<size_t>(M),
static_cast<size_t>(N),
static_cast<size_t>(K),
alpha_,
A->Data<float>(),
static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
packed_b_.get(),
c_data != nullptr ? beta_ : 0.0f,
y_data,
static_cast<size_t>(N),
thread_pool);
} else {
auto output_span = Y->MutableDataAsSpan<float>();
std::fill(output_span.begin(), output_span.end(), 0.0f);
}
}

ComputeActivation(y_data, SafeInt<size_t>(M) * N, thread_pool);
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/math/gemm_helper.h
@@ -56,7 +56,8 @@ class GemmHelper {
status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast");

// it is possible the input is empty tensor, for example the output of roipool in fast rcnn.
ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0);
// it is also possible that K == 0
ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0);
}

ptrdiff_t M() const { return M_; }
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cuda/math/gemm.cc
@@ -137,6 +137,16 @@ Status Gemm<T>::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const
}
}

if (K == 0) {
if (beta_ == 0 || B == nullptr) {
// When K == 0 the matrix product contributes nothing, so the output must be
// filled with zeros unless a bias was provided (beta_ != 0 and B != nullptr).
Fill<CudaT>(Stream(ctx), reinterpret_cast<CudaT*>(Y->MutableData<T>()), CudaT(0.f),
Y->Shape().Size());
}
return Status::OK();
}
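
Editorial note: the Fill call above writes CudaT(0.f) across the whole output buffer on the device. A minimal sketch of an equivalent zero fill using only the CUDA runtime API is shown below; ZeroFillOutput is an illustrative name, not the helper used in the diff. Byte-wise zeroing is valid here because IEEE 754 +0.0f has an all-zero bit pattern.

// Illustrative sketch only, not part of the diff: zero-fill an M x N float output on the GPU.
#include <cuda_runtime.h>
#include <cstddef>

cudaError_t ZeroFillOutput(float* d_y, std::size_t M, std::size_t N, cudaStream_t stream) {
  // +0.0f is all-zero bits, so a byte-wise memset matches Fill(..., CudaT(0.f)) for float outputs.
  return cudaMemsetAsync(d_y, 0, M * N * sizeof(float), stream);
}
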

CudaT alpha = ToCudaType<T>::FromFloat(alpha_);
CudaT beta = ToCudaType<T>::FromFloat(beta_);
// Gemm, note that CUDA assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
15 changes: 15 additions & 0 deletions onnxruntime/test/contrib_ops/quant_gemm_test.cc
@@ -202,5 +202,20 @@ TEST(QuantGemmTest, GEMM) {
RunQuantGemmTestBatch(4, 8, 68);
}

TEST(QuantGemmTest, EmptyInputsNoBiasNoZp) {
OpTester test("QGemm", 1, onnxruntime::kMSDomain);
test.AddAttribute<int64_t>("transA", 0);
test.AddAttribute<int64_t>("transB", 0);
test.AddAttribute<float>("alpha", 1.f);

test.AddInput<int8_t>("A", {4, 0}, {});
test.AddInput<float>("a_scale", {}, {1.f});
test.AddInput<int8_t>("a_zero_point", {}, {-1});

test.AddInput<int8_t>("B", {0, 4}, {});
test.AddInput<float>("b_scale", {}, {1.f});
test.AddInput<int8_t>("b_zero_point", {}, {-1});

test.AddOutput<float>("Y", {4, 4}, std::vector<float>(16, 0.f));
test.Run();
}

} // namespace test
} // namespace onnxruntime
40 changes: 40 additions & 0 deletions onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -641,6 +641,46 @@ TYPED_TEST(GemmOpTypedTests, GemmEmptyTensor) {
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) {
OpTester test("Gemm", 13);

test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);

test.AddInput<TypeParam>("A", {4, 0}, {});
test.AddInput<TypeParam>("B", {0, 4}, {});
test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(1.0f)));

test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
kOpenVINOExecutionProvider})
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) {
OpTester test("Gemm", 13);

test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", .0f);

test.AddInput<TypeParam>("A", {4, 0}, {});
test.AddInput<TypeParam>("B", {0, 4}, {});
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));

test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
kOpenVINOExecutionProvider})
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, MissingBias) {
OpTester test("Gemm", 11);
