From 54983d06d846042dfb5dc9488f86c6896b1c2787 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 19 Sep 2024 07:52:42 +0000
Subject: [PATCH 1/3] test

---
 python/sglang/srt/managers/tp_worker.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index fe9afc9f31..991b7cc0b4 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -214,6 +214,7 @@ def __init__(
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.do_not_get_new_batch = False
+        self.new_batch = None
 
     def exposed_step(self, recv_reqs: List):
         try:
@@ -248,13 +249,12 @@ def exposed_step(self, recv_reqs: List):
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
-            new_batch = None
-        else:
-            new_batch = self.get_new_prefill_batch()
-        self.do_not_get_new_batch = False
+        if self.new_batch is None and self.running_batch is None:
+            self.new_batch = self.get_new_prefill_batch()
 
-        if new_batch is not None:
+        if self.new_batch is not None:
+            new_batch = self.new_batch
+            self.new_batch = None
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
 
@@ -541,12 +541,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             self.tree_cache,
         )
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
-        return new_batch
 
-    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
-        batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend(self.model_config.vocab_size)
+        return new_batch
 
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
@@ -558,6 +558,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             logits_output = self.model_runner.forward(batch)
+            assert self.new_batch is None
+            self.new_batch = self.get_new_prefill_batch()
             next_token_ids = self.model_runner.sample(logits_output, batch)
 
             batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
@@ -623,6 +625,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
         else:
             assert batch.extend_num_tokens != 0
             logits_output = self.model_runner.forward(batch)
+            assert self.new_batch is None
+            self.new_batch = self.get_new_prefill_batch()
             embeddings = logits_output.embeddings.tolist()
 
         # Check finish conditions
@@ -751,6 +755,8 @@ def forward_decode_batch(self, batch: ScheduleBatch):
 
         # Forward and sample the next tokens
         logits_output = self.model_runner.forward(batch)
+        assert self.new_batch is None
+        self.new_batch = self.get_new_prefill_batch()
         next_token_ids = self.model_runner.sample(logits_output, batch)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids

From 6cce0db750ba21c74537c6c429b27ee3660295d4 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 19 Sep 2024 07:58:08 +0000
Subject: [PATCH 2/3] update

---
 python/sglang/srt/managers/tp_worker.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 991b7cc0b4..24cdf0ed30 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -558,8 +558,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             logits_output = self.model_runner.forward(batch)
-            assert self.new_batch is None
-            self.new_batch = self.get_new_prefill_batch()
+            if self.new_batch is None:
+                self.new_batch = self.get_new_prefill_batch()
             next_token_ids = self.model_runner.sample(logits_output, batch)
 
             batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
@@ -625,8 +625,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
         else:
             assert batch.extend_num_tokens != 0
             logits_output = self.model_runner.forward(batch)
-            assert self.new_batch is None
-            self.new_batch = self.get_new_prefill_batch()
+            if self.new_batch is None:
+                self.new_batch = self.get_new_prefill_batch()
             embeddings = logits_output.embeddings.tolist()
 
         # Check finish conditions
@@ -755,8 +755,8 @@ def forward_decode_batch(self, batch: ScheduleBatch):
 
         # Forward and sample the next tokens
         logits_output = self.model_runner.forward(batch)
-        assert self.new_batch is None
-        self.new_batch = self.get_new_prefill_batch()
+        if self.new_batch is None:
+            self.new_batch = self.get_new_prefill_batch()
         next_token_ids = self.model_runner.sample(logits_output, batch)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids

From 331e9540948bde2cefe9b7619a420d30a3fb332e Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 19 Sep 2024 20:58:22 +0000
Subject: [PATCH 3/3] fix

---
 python/sglang/srt/managers/policy_scheduler.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index ada3904182..4826ac2153 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -251,6 +251,7 @@ def add_req_state(r, insert_sort=False):
             )
         else:
             # Chunked prefill
+            return False
             trunc_len = self.rem_chunk_tokens
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
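
Background on the pattern in this series: the patches move the call to
get_new_prefill_batch() so that it runs right after model_runner.forward()
has launched the forward pass. On an asynchronous device queue such as CUDA,
forward() returns to the host as soon as the kernels are enqueued, so the
CPU-side work of assembling the next prefill batch (queue filtering, prefix
matching, tensor building) can proceed while the GPU is still computing; the
host only blocks again when it reads the results back during sampling. The
sketch below shows the same overlap pattern in isolation. It is a minimal,
hypothetical example, not sglang code: Worker, ToyModelRunner, ToyScheduler,
and all of their methods are illustrative stand-ins.

    import time
    from typing import List, Optional


    class ToyModelRunner:
        """Stand-in for a GPU model runner whose forward() is asynchronous."""

        def forward(self, batch: List[int]) -> dict:
            # With a real CUDA backend this only enqueues kernels and returns
            # immediately; here we just tag the batch with its launch time.
            return {"launched_at": time.time(), "batch": batch}

        def sample(self, logits_output: dict, batch: List[int]) -> List[int]:
            # Reading the results is the synchronization point: the host
            # waits here until the device has finished the forward pass.
            return [0] * len(batch)


    class ToyScheduler:
        """Stand-in for the CPU-side batch scheduler."""

        def __init__(self, waiting: List[List[int]]):
            self.waiting = waiting

        def get_next_batch(self) -> Optional[List[int]]:
            # CPU-heavy scheduling work: exactly the cost we want to hide
            # behind the in-flight GPU computation.
            return self.waiting.pop(0) if self.waiting else None


    class Worker:
        def __init__(self, runner: ToyModelRunner, scheduler: ToyScheduler):
            self.runner = runner
            self.scheduler = scheduler
            self.next_batch = None  # batch prepared while the GPU is busy

        def step(self, batch: List[int]) -> List[int]:
            logits_output = self.runner.forward(batch)  # async kernel launch
            if self.next_batch is None:
                # Overlap window: schedule the next batch while the GPU runs.
                self.next_batch = self.scheduler.get_next_batch()
            return self.runner.sample(logits_output, batch)  # sync point

The same consideration explains the shape of the later patches: the prefetch
must not overwrite a batch that is already queued, so patch 2 relaxes the
initial assert to a guard (if self.new_batch is None:), and forward_step()
takes ownership of self.new_batch, resetting it to None, before the next
prefetch is allowed.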