diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 53842d88c3..c4c91c7112 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -429,7 +429,7 @@ def alloc_token_slots(self, num_tokens: int): def prepare_for_extend(self, vocab_size: int): self.forward_mode = ForwardMode.EXTEND - bs = self.batch_size() + bs = len(self.reqs) reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) @@ -509,7 +509,7 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): self.extend_logprob_start_lens_cpu.extend([0] * running_bs) def check_decode_mem(self): - bs = self.batch_size() + bs = len(self.reqs) if self.token_to_kv_pool.available_size() >= bs: return True @@ -685,7 +685,7 @@ def prepare_for_decode(self, input_ids=None): self.seq_lens.add_(1) # Alloc mem - bs = self.batch_size() + bs = len(self.reqs) self.out_cache_loc = self.alloc_token_slots(bs) self.req_to_token_pool.req_to_token[