// NOTE(review): this span is a unified-diff hunk for quant_gemm.cc collapsed onto long lines; segments introduced
// by '-' are removed by the patch and segments introduced by '+' are added. Template arguments (e.g. after
// static_cast / cast / EigenMatrixMapRowMajor) are not visible here — presumably lost in extraction; confirm
// against the real file before editing any code token.
//
// GemmBroadcastBiasScaleBack (added by this hunk; replaces the removed GemmBroadcastBiasScaleBackWithCast template):
// fills the M x N row-major float buffer `output` with bias_data * a_scale * b_scale, choosing the broadcast
// from bias_shape:
//   - bias_shape.Size() == 1                          -> the single scaled value is written to every element
//   - NumDimensions() == 1 or bias_shape[0] == 1      -> N scaled values are assigned to every row
//   - bias_shape[1] == 1                              -> M scaled values are assigned to every column
//   - otherwise                                       -> element-wise over the full (M, N) bias
// Parameters: M, N = output rows/columns; bias_data = int32 bias values; bias_shape = shape used only to pick the
// branch above; output = destination buffer written in place; a_scale, b_scale = float multipliers.
diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index aee31c1fc9cd..0d36d5301e94 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -12,35 +12,55 @@ namespace onnxruntime { namespace contrib { -template -void GemmBroadcastBiasScaleBackWithCast(Eigen::Index M, Eigen::Index N, const S* c_data, const TensorShape& bias_shape, - T* output, float a_scale, float b_scale) { - auto output_mat = EigenMatrixMapRowMajor(output, M, N); +void GemmBroadcastBiasScaleBack(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, + const TensorShape& bias_shape, + float* output, float a_scale, float b_scale) { + auto output_mat = EigenMatrixMapRowMajor(output, M, N); if (bias_shape.Size() == 1) { // C is (), (1,) or (1, 1), set the scalar - const auto constant = static_cast(static_cast(c_data[0]) * a_scale * b_scale); + const auto constant = static_cast(bias_data[0]) * a_scale * b_scale; output_mat.setConstant(constant); } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { // C is (N,) or (1, N) - output_mat.rowwise() = (ConstEigenVectorMap(c_data, N).transpose().template cast() * - a_scale * b_scale) - .template cast(); + output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().cast() * + a_scale * b_scale; } else if (bias_shape[1] == 1) { // C is (M, 1) - output_mat.colwise() = (ConstEigenVectorMap(c_data, M).template cast() * - a_scale * b_scale) - .template cast(); + output_mat.colwise() = ConstEigenVectorMap(bias_data, M).cast() * + a_scale * b_scale; } else { // C is (M, N), no broadcast needed. 
// GemmBroadcastBiasAndClip (added further along the following line): uses the same four bias_shape cases as
// GemmBroadcastBiasScaleBack, but writes uint8_t values into `output` and applies no scale factors; each int32
// bias value is passed through the local `clip` lambda before being stored.
// NOTE(review): despite its name, `clip` keeps only the low 8 bits of the value (v & 0xff), so values outside
// [0, 255] wrap (e.g. 256 -> 0, -1 -> 255) instead of saturating — confirm that wrap-around is the intended
// behaviour for the K == 0 path that calls this helper.
- output_mat = (ConstEigenMatrixMapRowMajor(c_data, M, N).template cast() * - a_scale * b_scale) - .template cast(); + output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).cast() * + a_scale * b_scale; + } +} + +void GemmBroadcastBiasAndClip(Eigen::Index M, Eigen::Index N, const int32_t* bias_data, + const TensorShape& bias_shape, uint8_t* output) { + auto clip = [](int32_t v) -> uint8_t { + return static_cast(v & 0xff); + }; + + auto output_mat = EigenMatrixMapRowMajor(output, M, N); + if (bias_shape.Size() == 1) { + // C is (), (1,) or (1, 1), set the scalar + const auto constant = clip(bias_data[0]); + output_mat.setConstant(constant); + } else if (bias_shape.NumDimensions() == 1 || bias_shape[0] == 1) { + // C is (N,) or (1, N) + output_mat.rowwise() = ConstEigenVectorMap(bias_data, N).transpose().unaryExpr(clip); + } else if (bias_shape[1] == 1) { + // C is (M, 1) + output_mat.colwise() = ConstEigenVectorMap(bias_data, M).unaryExpr(clip); + } else { + // C is (M, N), no broadcast needed. + output_mat = ConstEigenMatrixMapRowMajor(bias_data, M, N).unaryExpr(clip); } } /// /// This function will attempt to handle the case where K is zero while M and N are not -/// We need to fill the output either with zeros or with c_data if present. +/// We need to fill the output either with zeros or with bias_data if present. 
/// /// /// @@ -48,7 +68,7 @@ void GemmBroadcastBiasScaleBackWithCast(Eigen::Index M, Eigen::Index N, const S* /// /// /// -/// +/// static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor& y, const AllocatorPtr& allocator, const Tensor* y_scale, const Tensor* y_zp, const Tensor* bias) { const auto output_dims = y.Shape().GetDims(); @@ -58,22 +78,21 @@ static void HandleZeroKCase(const Tensor& a_scale, const Tensor& b_scale, Tensor const float b_scale_value = b_scale.Data()[0]; if (y_zp == nullptr) { - // Either fill with c_data if present or 0 - int8_t* output = reinterpret_cast(y.MutableDataRaw()); + // Either fill with bias_data if present or 0 + uint8_t* output = reinterpret_cast(y.MutableDataRaw()); if (bias != nullptr) { - GemmBroadcastBiasScaleBackWithCast(M, N, bias->Data(), bias->Shape(), output, - 1.f, 1.f); + GemmBroadcastBiasAndClip(M, N, bias->Data(), bias->Shape(), output); } else { - EigenMatrixMapRowMajor output_mat(output, M, N); + EigenMatrixMapRowMajor output_mat(output, M, N); output_mat.setZero(); } } else { if (bias != nullptr) { - // scale c_data back to float with result = c_data * a_scale * b_scale. + // scale bias_data back to float with result = bias_data * a_scale * b_scale. Tensor scaled_back(DataTypeImpl::GetType(), y.Shape(), allocator); - GemmBroadcastBiasScaleBackWithCast(M, N, bias->Data(), bias->Shape(), - scaled_back.MutableData(), - a_scale_value, b_scale_value); + GemmBroadcastBiasScaleBack(M, N, bias->Data(), bias->Shape(), + scaled_back.MutableData(), + a_scale_value, b_scale_value); // re-quantize if (y_zp->IsDataType()) {