[draft]GQA MLFloat16 cpu #22102

Draft · wants to merge 4 commits into main
Changes from 3 commits
38 changes: 20 additions & 18 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -20,33 +20,34 @@ template <typename T>
void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t j = begin; j != end; ++j) {
float* x = reinterpret_cast<T*>(score) + j * D;
float* y = x;
T* x = reinterpret_cast<T*>(score) + j * D;
T* y = x;

float max = -std::numeric_limits<float>::infinity();
for (int i = 0; i < D; i++) {
if (max < x[i])
max = x[i];
float x_i = static_cast<float>(x[i]);
if (max < x_i)
max = x_i;
}

if (max < 0.0f) {
max = 0.0f;
}

for (int i = 0; i < D; i++) {
y[i] = expf(x[i] - max);
y[i] = static_cast<T>(expf(static_cast<float>(x[i]) - max));
Contributor:
When T is float16, it will overflow easily. I think we cannot use softmax in place for float16; a float buffer is needed.

Member (Author):
Shouldn't expf(x[i] - max) fall within (0, 1]?

Contributor:
Right, it will not overflow. But I think we need to keep the intermediate data as float/double to preserve accuracy: every cast from float to half loses precision, and that loss accumulates when we compute the sum below.
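As a rough sketch of the float-buffer approach suggested above (illustration only, not code from this PR; the function name and the caller-provided scratch buffer are hypothetical, and it relies only on the <cmath>/<limits> facilities the existing code already uses):

template <typename T>
void ComputeSmoothSoftmaxWithFloatBuffer(T* score, float* scratch, int D) {
  float max = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < D; i++) {
    scratch[i] = static_cast<float>(score[i]);  // upcast each element exactly once
    if (max < scratch[i]) max = scratch[i];
  }
  if (max < 0.0f) {
    max = 0.0f;
  }

  double sum = 0.0;
  for (int i = 0; i < D; i++) {
    scratch[i] = expf(scratch[i] - max);  // stays in float, no half round-trips
    sum += scratch[i];
  }
  sum += exp(static_cast<double>(-max));  // smooth-softmax term, as in the code above

  for (int i = 0; i < D; i++) {
    score[i] = static_cast<T>(scratch[i] / static_cast<float>(sum));  // single downcast at the end
  }
}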

}

double sum = 0.0;

for (int i = 0; i < D; i++) {
sum += x[i];
sum += static_cast<float>(x[i]);
}

sum += exp(static_cast<double>(-max));

for (int i = 0; i < D; i++) {
y[i] = x[i] / (float)sum;
y[i] = static_cast<T>(static_cast<float>(x[i]) / static_cast<float>(sum));
}
}
});
@@ -61,8 +62,8 @@ template <typename T>
void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t j = begin; j != end; ++j) {
float* x = reinterpret_cast<T*>(score) + j * D;
float* y = x;
T* x = reinterpret_cast<T*>(score) + j * D;
T* y = x;

// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if
@@ -71,26 +72,27 @@ void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
// max) / (e^(x1 - max) + ... + e^(xn - max))
float max = -std::numeric_limits<float>::infinity();
for (int i = 0; i < D; i++) {
if (max < x[i])
max = x[i];
float x_i = static_cast<float>(x[i]);
if (max < x_i)
max = x_i;
}
for (int i = 0; i < D; i++) {
y[i] = expf(x[i] - max);
y[i] = static_cast<T>(expf(static_cast<float>(x[i]) - max));
}

double sum = 0.0;

for (int i = 0; i < D; i++) {
sum += x[i];
sum += static_cast<float>(x[i]);
}

if (sum == 0) {
for (int i = 0; i < D; i++) {
y[i] = 1.0f / (float)D;
y[i] = static_cast<T>(1.0f / static_cast<float>(D));
}
} else {
for (int i = 0; i < D; i++) {
y[i] = x[i] / (float)sum;
y[i] = static_cast<T>(x[i] / static_cast<float>(sum));
}
}
}
@@ -105,9 +107,9 @@ inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPoo
template <typename T>
void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, float softcap) {
for (int i = 0; i < sequence_length; i++) {
scores[i] = scores[i] / softcap;
scores[i] = std::tanh(scores[i]);
scores[i] = scores[i] * softcap;
float score = static_cast<float>(scores[i]) / softcap;
score = std::tanh(score);
scores[i] = static_cast<T>(score * softcap);
}
}

97 changes: 91 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -10,6 +10,21 @@
namespace onnxruntime {
namespace contrib {

namespace {
template <typename T>
struct EigenType;

template <>
struct EigenType<float> {
using Type = float;
};

template <>
struct EigenType<MLFloat16> {
using Type = Eigen::half;
};
}

Check warning on line 26 in onnxruntime/contrib_ops/cpu/bert/attention_utils.cc (GitHub Actions / Optional Lint C++): [cpplint] Anonymous namespace should be terminated with "// namespace" [readability/namespace] [5]
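The fix cpplint asks for is simply a trailing comment on the closing brace of the anonymous namespace added in this hunk:

}  // namespace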

// Reshape Q/K/V from BxSxD to BxSxNxH
inline Status Reshape_BSD_to_BSNH(Tensor* qkv,
int batch_size,
@@ -48,13 +63,43 @@
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_1 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<typename EigenType<T>::Type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<typename EigenType<T>::Type> output_vec_map(output, num_elements);

output_vec_map = input_1_vec_map + static_cast<typename EigenType<T>::Type>(per_iter_bh.ScalarInput0<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_0_vec_map(input_0, num_elements);

auto* output = reinterpret_cast<typename EigenType<T>::Type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<typename EigenType<T>::Type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + static_cast<typename EigenType<T>::Type>(per_iter_bh.ScalarInput1<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_0_vec_map(input_0, num_elements);

const auto* input_1 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<typename EigenType<T>::Type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<typename EigenType<T>::Type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + input_1_vec_map;
}}; // For element-wise add

// Allocate space for output of Q(BS, D) + bias(D)
@@ -114,6 +159,7 @@
return Status::OK();
}


// Add bias + reshape for each of Q/K/V
// This is used in decoder_with_past when the sequence length is 1
template <typename T>
@@ -129,16 +175,47 @@
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
using eigen_type = typename EigenType<T>::Type;
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
//per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();

Check warning on line 182 in onnxruntime/contrib_ops/cpu/bert/attention_utils.cc (GitHub Actions / Optional Lint C++): [cpplint] Should have a space between // and comment [whitespace/comments] [4]
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_1_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput0<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput1<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + input_1_vec_map;
}}; // For element-wise add

// Get Q's bias from combined bias
@@ -219,6 +296,10 @@
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,

Check warning on line 300 in onnxruntime/contrib_ops/cpu/bert/attention_utils.cc (GitHub Actions / Optional Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

Check warning on line 301 in onnxruntime/contrib_ops/cpu/bert/attention_utils.cc (GitHub Actions / Optional Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]

template <typename T>
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +323,9 @@
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);

template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);

} // namespace contrib
} // namespace onnxruntime
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -131,7 +131,7 @@
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
memset((void*)present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));

Check warning on line 134 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (GitHub Actions / Optional Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
Check warning on line 134 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (GitHub Actions / Optional Lint C++): [cpplint] Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]
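A version of this memset call that would satisfy both checks could look like the following (illustrative sketch only: the line is split and the C-style cast is replaced with the cast cpplint suggests):

      memset(reinterpret_cast<void*>(present_key), 0,
             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));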
}

const size_t loop_len = batch_size * num_heads_;
@@ -190,8 +190,8 @@
q = Q + q_input_chunk_length * i;
}

math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output,
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, static_cast<T>(alpha), q,

Check warning on line 193 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (GitHub Actions / Optional Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
static_cast<int>(head_size), k, static_cast<int>(head_size), static_cast<T>(0.0f) /*bata*/, output,

Check warning on line 194 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (GitHub Actions / Optional Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
static_cast<int>(present_buffer_sequence_length), nullptr);

// compute Softmax
@@ -200,7 +200,7 @@
size_t seq_causal_length = past_seqlen + seq + 1;
if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
output_softmax[total_seq_id] = static_cast<T>(0.f);
}
if (softcap_ > 0.f) {
ComputeAttentionSoftcapInplace(output_softmax + seq_causal_length - local_window_size_ - 1,
@@ -226,7 +226,7 @@

// set causal [seq_causal_length, total_seqlen) to 0.f
for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
output_softmax[total_seq_id] = static_cast<T>(0.f);
}

output_softmax += present_buffer_sequence_length;
@@ -261,7 +261,7 @@
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
memset((void*)present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));

Check warning on line 264 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (GitHub Actions / Optional Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
Check warning on line 264 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (GitHub Actions / Optional Lint C++): [cpplint] Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]
}

const size_t loop_len = batch_size * num_heads_;
@@ -308,10 +308,10 @@
T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, static_cast<T>(1.f), /*alpha*/
attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
static_cast<T>(0.0f) /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
}
});
}
24 changes: 14 additions & 10 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -22,16 +22,20 @@ namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
GroupQueryAttention,
kMSDomain,
1,
float,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
GroupQueryAttention<float>);
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GroupQueryAttention, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()), \
GroupQueryAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
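For reference, REGISTER_KERNEL_TYPED(MLFloat16) expands to the same registration shape as the float-only block it replaces, with the type substituted:

ONNX_OPERATOR_TYPED_KERNEL_EX(
    GroupQueryAttention,
    kMSDomain,
    1,
    MLFloat16,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
    GroupQueryAttention<MLFloat16>);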

template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)