Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed RPB and LSE flags from template arguments in favor of runtime args #164

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

AdityaKane2001
Copy link

@alihassanijr

As discussed, removed RPB and LSE flags from template args and added them to runtime args (under struct Params). To make reviewing easier, I have only committed the kernel definition files and not the autogen'd files (the diff is too large to review).

Requesting review.

Comment on lines 283 to 284
assert(
(!has_rpb || !kHasCausalDims) && "Causal NA does not support RPB yet.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this to host side instead?

I highly prefer moving assertions to outside the kernel as much as possible, I've had some bad experiences in the past with device side asserts.

We might already have a check for it on host, but if we don't, just add an if statement and raise an error in kernel_forward.h. There should be some checks in the already that you can copy paste.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, please take a look if that's fine.

@@ -410,6 +410,25 @@ inline NA3dDim tuple_to_na_dim(std::tuple<int32_t, int32_t, int32_t> v) {
return NA3dDim(std::get<0>(v), std::get<1>(v), std::get<2>(v));
}

template <typename BoolTupleType>
bool bool_tuple_or(BoolTupleType tuple);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bool bool_tuple_or(BoolTupleType tuple);
bool tuple_or(BoolTupleType tuple);

Nit: logical operators like or mostly make sense for boolean types, and in this case, only causal mask.

bool should_dump_lse = logsumexp_ptr != nullptr;

bool has_causal_dim = bool_tuple_or(is_causal);
assert((!has_rpb || !has_causal_dim) && "Causal NA does not support RPB yet.")
Copy link
Member

@alihassanijr alihassanijr Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use NATTEN_CHECK like below so that the error is visible to the user, and more consistent with the rest of the code.

Comment on lines +42 to +43
#include <cassert>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove?

Comment on lines +117 to +119
// removed in favor of runtime args
// bool kSupportsRPB_ = false,
// bool kStoresLSE_ = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just remove them.

Comment on lines +135 to +139

// replaced in favor of runtime args
// static constexpr bool kSupportsRPB = kSupportsRPB_;
// static constexpr bool kStoresLSE = kStoresLSE_;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

@@ -179,6 +183,10 @@ struct FusedNeighborhoodAttentionKernel {
// [num_heads, num_queries_post_partitioning] - can be null
lse_scalar_t* logsumexp_ptr = nullptr;

// StoresLSE/SupportsRPB flags
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

@@ -272,12 +280,11 @@ struct FusedNeighborhoodAttentionKernel {
output_ptr += (first_query * o_strideM).sum() +
(dilation_idx * o_stride_dilation).sum() + head_id * o_strideH;

if constexpr (kSupportsRPB) {
if (rpb_ptr != nullptr) {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove extra line

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants