From f39a0197fdc39aeb468f8b51f7b9e7631a02bf98 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Tue, 24 Sep 2024 22:50:31 -0700
Subject: [PATCH] Revert "kernel: use tensor cores for flashinfer gqa kernels" (#1511)

---
 python/sglang/srt/layers/attention_backend.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py
index 71dbfe0e3b..d7c1cf39d8 100644
--- a/python/sglang/srt/layers/attention_backend.py
+++ b/python/sglang/srt/layers/attention_backend.py
@@ -86,17 +86,9 @@ def __init__(self, model_runner: ModelRunner):
         super().__init__()
         self.model_runner = model_runner
 
-        local_num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
-        )
-        if (
-            not _grouped_size_compiled_for_decode_kernels(
-                local_num_qo_heads, local_num_kv_heads
-            )
-            or local_num_qo_heads // local_num_kv_heads > 4
+        if not _grouped_size_compiled_for_decode_kernels(
+            model_runner.model_config.num_attention_heads // model_runner.tp_size,
+            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
         ):
             self.decode_use_tensor_cores = True
         else:
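
Note (not part of the patch): the hunk above reverts a heuristic that additionally
forced the tensor-core decode path when the GQA group size (query heads per kv head)
exceeded 4. The following self-contained Python sketch contrasts the decision logic
before and after the revert. The stub below only mirrors the role of flashinfer's
_grouped_size_compiled_for_decode_kernels; its group-size set (1, 2, 4, 8) is an
assumption for illustration, and the real function queries flashinfer itself.

def _grouped_size_compiled_for_decode_kernels(num_qo_heads: int,
                                              num_kv_heads: int) -> bool:
    # Stub: pretend decode kernels are precompiled for these GQA group
    # sizes. The set (1, 2, 4, 8) is assumed here, not taken from flashinfer.
    return (num_qo_heads % num_kv_heads == 0
            and num_qo_heads // num_kv_heads in (1, 2, 4, 8))

def use_tensor_cores_before_revert(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Reverted heuristic: use tensor cores when no compiled decode kernel
    # exists OR the GQA group size is larger than 4.
    return (not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads)
            or num_qo_heads // num_kv_heads > 4)

def use_tensor_cores_after_revert(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Restored behavior: use tensor cores only when no compiled decode
    # kernel exists for this head grouping.
    return not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads)

if __name__ == "__main__":
    # Example: 64 local query heads, 8 local kv heads (group size 8).
    print(use_tensor_cores_before_revert(64, 8))  # True: group size 8 > 4
    print(use_tensor_cores_after_revert(64, 8))   # False: a compiled kernel exists

Under these assumptions, a model with GQA group size 8 stops being routed to the
tensor-core path after the revert, which is exactly the behavioral delta the patch
undoes.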