Persistent version of Flash Attention #2407

Draft · manman-ren wants to merge 2 commits into main

Conversation

manman-ren (Contributor) commented Aug 2, 2024

Added two more variants: triton_tutorial_flash_v2_persistent and triton_tutorial_flash_v2_persistent_tma.

The variants handle the non-causal case only. For the causal case there are two invocations of attn_fwd_inner, which means we would have one outer loop and two inner loops:

for ...      # persistent loop over flattened tile ids
    for ...  # attn_fwd_inner pass over the fully unmasked K/V blocks
    for ...  # attn_fwd_inner pass over the masked (diagonal) K/V blocks

It is not clear how to flatten this nest into a single 1D loop.
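
To make the flattening concrete, here is a minimal, hypothetical sketch (not code from this PR) of the same 1D persistent scheduling applied to a trivial row-scaling kernel; the kernel name and all parameters are invented for illustration:

import torch
import triton
import triton.language as tl

@triton.jit
def _row_scale_persistent(X, Out, M, N, stride_m, scale,
                          BLOCK_N: tl.constexpr, NUM_PROGS: tl.constexpr):
    # Flatten the 2D tile space (row, n-block) into a single tile_id and let
    # each of the NUM_PROGS resident programs walk it with a stride -- the
    # same 1D persistent loop the non-causal attention variant runs over
    # (batch*head, m-block) tiles.
    n_blocks = tl.cdiv(N, BLOCK_N)
    total_tiles = M * n_blocks
    for tile_id in range(tl.program_id(0), total_tiles, NUM_PROGS):
        row = tile_id // n_blocks   # recover the 2D tile coordinates
        nb = tile_id % n_blocks
        offs = nb * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = offs < N
        x = tl.load(X + row * stride_m + offs, mask=mask)
        tl.store(Out + row * stride_m + offs, x * scale, mask=mask)

# Launch only as many programs as the GPU has SMs; each program loops over
# its strided slice of the flattened tile space instead of owning one tile.
x = torch.randn(512, 1000, device="cuda")
out = torch.empty_like(x)
num_sms = torch.cuda.get_device_properties(0).multi_processor_count
_row_scale_persistent[(num_sms,)](x, out, x.shape[0], x.shape[1], x.stride(0),
                                  0.5, BLOCK_N=128, NUM_PROGS=num_sms)

The causal difficulty is then visible in this shape: with two inner loops per tile, a single flat tile_id no longer maps onto one contiguous piece of per-tile work.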

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

@triton.autotune(list(filter(keep, configs)), key=["N_CTX"])
@triton.jit
def _attn_fwd_persistent_tma(Q, Out, desc_q, desc_k, desc_v, sm_scale, M, desc_o, #
Is this a copy of _attn_fwd_persistent but with TMA changes?
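
For context on the question above: with the experimental TMA descriptor API that Triton exposed around this time (names and signatures are version-dependent; treat the calls below as assumptions), the TMA variant replaces in-kernel pointer arithmetic with descriptor-based tile loads. A minimal, hypothetical tile-copy sketch, Hopper-only and not code from this PR:

import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_2d_tma_descriptor

@triton.jit
def _tile_copy_tma(desc_in, desc_out,  # host-built TMA descriptors (assumed API)
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    off_m = tl.program_id(0) * BLOCK_M
    off_n = tl.program_id(1) * BLOCK_N
    # TMA load: the hardware fetches the (BLOCK_M, BLOCK_N) tile described by
    # desc_in at the given element offsets -- no block pointers in the kernel.
    tile = tl._experimental_descriptor_load(desc_in, [off_m, off_n],
                                            [BLOCK_M, BLOCK_N], tl.float16)
    tl._experimental_descriptor_store(desc_out, tile, [off_m, off_n])

M, N, BM, BN = 256, 256, 64, 64
x = torch.randn(M, N, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
desc_x = create_2d_tma_descriptor(x.data_ptr(), M, N, BM, BN, x.element_size())
desc_y = create_2d_tma_descriptor(y.data_ptr(), M, N, BM, BN, y.element_size())
_tile_copy_tma[(M // BM, N // BN)](desc_x, desc_y, BLOCK_M=BM, BLOCK_N=BN)

If _attn_fwd_persistent_tma follows this pattern, the persistent scheduling logic would be unchanged and only the Q/K/V/O tile loads and stores would differ, which matches the reading in the question above.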
