Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a inplace concat custom op based on CUDA VMM API #9126

Draft
wants to merge 12 commits into
base: develop
Choose a base branch
from

Conversation

lszxb
Copy link
Contributor

@lszxb lszxb commented Sep 11, 2024

PR types

Performance optimization

PR changes

Others

Description

这一PR尝试为当前的大模型推理过程增加基于CUDA VMM API的inplace concat支持(原理类似于vAttention),从而避免在每一个解码步都复制一次整个KV Cache。
该功能暂时只实现了自定义算子,未来还需要增加相关的pass以自动适配其他模型。
目前这一PR在llama模型上应用了这一方案,在3072 input+1024 output的情况下大约有10%的提升。

目前主要的思路是:

  • 使用一种特殊的Tensor,其显存由VMM API分配,这种Tensor使用特殊的phi::Allocation,在创建时预留大量的虚拟地址空间,可以在必要时分配物理页映射到虚拟地址空间。
  • 为了兼容剩余的调用,cache的shape为batch x seq_len x num_head x head_dim,但由于状态在cache的尾部追加,cache的内存布局应该是seq_len x batch x num_head x head_dim。
    vtensor_reserve_one_token自定义算子的语义大致如下:
  • 如果key_cache不是VTensor,则新分配一个VTensor,并将原先key_cache中的数据复制到这个新的VTensor中。然后使用VTensor的扩展机制,在尾部预留新的一个token的空间,并将key_states复制到这个新的空间中。
  • 如果key_cache是VTensor,直接使用VTensor的扩展机制,在尾部预留新的一个token的空间,并将key_states复制到这个新的空间中。

目前可能存在的问题:

  • 仅支持每次追加1个token的空间。
  • 目前分配的虚拟地址空间大小和block大小为定值(1GiB与32MiB),可能暴露相关的API给用户进行调整会更好?
  • 输入和输出的key_cache共享同一块空间,在某些情况下可能会产生冲突。
  • 该方法依赖于每个step使用的kv cache是同一个Tensor,若有某些其他操作改变了kv cache的Tensor(比如说clone到另一个Tensor),则会导致失效,因此也需要配合这个PR的优化才可使用(assign_out_操作会导致复制)。
  • 通过该算子分配的显存无法使用现有的Allocator进行统一管理。
  • 未能测试HIP的支持情况。

Copy link

paddle-bot bot commented Sep 11, 2024

Thanks for your contribution!

Copy link

codecov bot commented Sep 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 53.26%. Comparing base (e340457) to head (9bca022).
Report is 3 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9126      +/-   ##
===========================================
- Coverage    53.26%   53.26%   -0.01%     
===========================================
  Files          652      652              
  Lines       105587   105588       +1     
===========================================
  Hits         56237    56237              
- Misses       49350    49351       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant