
Add new Relax function to the batched model for evaluating query tokens over multiple time steps in parallel #156

Merged
masahi merged 7 commits into octoml:batch-serving on Jan 13, 2024

Conversation

masahi (Member) commented Jan 10, 2024

For speculative decoding, and for restoring KV cache entries of evicted parallel-sampling requests, we need to be able to compute logits over multiple tokens (time steps) while utilizing the KV cache for the past tensors. This is a hybrid of the prefill and decode functions, in that

  • prefill can compute logits over multiple tokens but doesn't read from the KV cache
  • decode reads from the KV cache but works on only one token at a time.

I'm introducing a new function, tentatively called evaluate_multi_query, for this purpose. multi_query_decode is also a good name.
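To make the contrast concrete, here is a toy NumPy illustration of which positions each mode attends to. The `toy_attention` helper, the shapes, and the variable names are made up for this sketch; they are not the PR's actual Relax functions.

```python
import numpy as np

def toy_attention(q, k, v, num_past):
    """Single-head attention where every query sees all `num_past` cached
    positions plus the new positions up to and including its own."""
    n_new = q.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])             # (n_new, num_past + n_new)
    causal = np.tril(np.ones((n_new, n_new), dtype=bool))
    mask = np.concatenate([np.ones((n_new, num_past), dtype=bool), causal], axis=1)
    scores = np.where(mask, scores, -np.inf)
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    return probs @ v                                     # (n_new, head_dim)

d = 8
past_k, past_v = np.random.randn(5, d), np.random.randn(5, d)    # 5 steps already cached
new_q, new_k, new_v = (np.random.randn(3, d) for _ in range(3))  # 3 new query tokens

# prefill: multiple query tokens, but nothing is read from the KV cache
prefill_out = toy_attention(new_q, new_k, new_v, num_past=0)

# decode: reads the cache, but only the single latest query token
decode_out = toy_attention(new_q[-1:], np.concatenate([past_k, new_k]),
                           np.concatenate([past_v, new_v]), num_past=7)

# evaluate_multi_query: multiple query tokens AND past K/V read from the cache
multi_out = toy_attention(new_q, np.concatenate([past_k, new_k]),
                          np.concatenate([past_v, new_v]), num_past=5)
```

In this toy picture, the new function simply combines decode's cache read with prefill's ability to handle several query positions at once.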

The changes in run_llama_batched_vllm.py show a new request type and how the new function is meant to be used. There is no change under serve yet, since this is purely a model change. After we agree on the approach, I'll integrate the new function into the engine to complete my parallel-sampling work. @yelite needs this for speculative decoding.

There is no attention kernel that reads from the KV cache and operates on multiple queries, except FlashInfer, which has BatchedPrefillWithKVCache. But we can emulate the behavior of such a kernel by materializing past KV tensors from the cache, concatenating them with the present tensors, and running the standard prefill attention. This is not efficient, but its correctness is much easier to verify. Until we integrate FlashInfer, or Flash Attention adds paged KV cache support, we can use this emulation.
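As a rough illustration of that emulation, here is a minimal NumPy sketch assuming a simple block-table paged cache layout. The block size, cache layout, and function names below are assumptions for illustration only, not the actual vLLM cache format used in this repo.

```python
import numpy as np

BLOCK_SIZE = 16  # assumed paged-cache block size, for illustration only

def gather_past_kv(cache_k, cache_v, block_table, num_past):
    """Materialize one sequence's past K/V from a paged cache.
    cache_k/cache_v: (num_blocks, BLOCK_SIZE, head_dim); block_table maps
    logical block index -> physical block id for this sequence."""
    idx = np.arange(num_past)
    blocks = np.asarray(block_table)[idx // BLOCK_SIZE]
    offsets = idx % BLOCK_SIZE
    return cache_k[blocks, offsets], cache_v[blocks, offsets]

def emulated_multi_query_attention(q, present_k, present_v,
                                   cache_k, cache_v, block_table, num_past):
    """Concatenate the materialized past K/V with the present K/V, then run
    ordinary prefill-style attention (causal over the new tokens)."""
    past_k, past_v = gather_past_kv(cache_k, cache_v, block_table, num_past)
    k = np.concatenate([past_k, present_k])
    v = np.concatenate([past_v, present_v])
    n_new = q.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    mask = np.concatenate([np.ones((n_new, num_past), dtype=bool),
                           np.tril(np.ones((n_new, n_new), dtype=bool))], axis=1)
    scores = np.where(mask, scores, -np.inf)
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    return probs @ v
```

A fused kernel such as FlashInfer's BatchedPrefillWithKVCache would read the cache pages inside the attention kernel instead of materializing and concatenating the past K/V in memory, which is the main source of the emulation's inefficiency.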

@sunggg @yelite @elvin-n

sunggg (Member) commented Jan 11, 2024

Thank you for the PR, @masahi! Which TVM should I use to run this?
Also, would you move examples/python/run_llama_batched_vllm.py under serve/ so that it can be a single folder for mlc-serve?

masahi (Member, Author) commented Jan 11, 2024

For now we need TVM from https://github.com/masahi/tvm/tree/vllm-cache-reconstruct. After apache/tvm#16376 is merged, I'll do a rebase.

> would you move examples/python/run_llama_batched_vllm.py under serve/ so that it can be a single folder for mlc-serve?

examples/python/run_llama_batched_vllm.py is not associated with mlc-serve. mlc-ai/main also has it. I added it to demonstrate how to use the batched llama model.

masahi (Member, Author) commented Jan 11, 2024

Opened #157, which uses the new Relax function from this PR to enable parallel-sampling eviction.

yelite left a comment

This looks great and should be sufficient for speculative decoding with a draft model.

By the way, is it still necessary to keep decode after we have a good kernel for evaluate_multi_query? Will there be a performance loss if we run evaluate_multi_query with one token from each sequence? If not, maybe we can just name this decode. Maybe we can even retire prefill if the kernel can specialize without degrading performance in the case where it doesn't need to read past KV from the cache.

masahi (Member, Author) commented Jan 12, 2024

> By the way, is it still necessary to keep decode after we have a good kernel for evaluate_multi_query? Will there be a performance loss if we run evaluate_multi_query with one token from each sequence? If not, maybe we can just name this decode. Maybe we can even retire prefill if the kernel can specialize without degrading performance in the case where it doesn't need to read past KV from the cache.

This is an interesting idea. I'd like to think that specialization brings performance advantages (a decode kernel shouldn't parallelize over the query tokens, since that dimension is small). FlashInfer implements a dedicated kernel for batched decode even though it also has BatchedPrefillWithKVCache. We'd have to measure and see.

The comparison is a bit subtle, since moving from a single query to multiple ones involves switching to an entirely different kernel implementation (vLLM to Flash Attention / FlashInfer), so performance can be affected by any number of factors besides the increase in the number of query tokens.

masahi merged commit 66a2e53 into octoml:batch-serving on Jan 13, 2024
1 check passed
Lunderberg pushed a commit to Lunderberg/mlc-llm that referenced this pull request on Jan 30, 2024
This PR reorganizes the artifact structure. We now have two separate
types of directories to store the libs/weights/...: one "prebuilt"
directory which holds all the prebuilt libs and weights downloaded from
the internet, and other model directories that are generated by local
builds.

CLI and test scripts are updated accordingly for this change.