Debug schedule optimization #1465

Closed
wants to merge 4 commits
1 change: 1 addition & 0 deletions python/sglang/srt/managers/policy_scheduler.py
@@ -251,6 +251,7 @@ def add_req_state(r, insert_sort=False):
            )
        else:
            # Chunked prefill
+            return False
            trunc_len = self.rem_chunk_tokens
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
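
The single added line makes the chunked-prefill branch bail out instead of truncating the request to the remaining chunk budget, so oversized prompts simply stay in the waiting queue; it reads as a debugging switch rather than a permanent policy change. A minimal, self-contained sketch of the effect (`FakeReq`, `try_add_request`, and the `disable_chunked_prefill` flag are hypothetical names for illustration, not the sglang API):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeReq:
    fill_ids: List[int]
    extend_input_len: int = 0

def try_add_request(
    req: FakeReq,
    rem_chunk_tokens: Optional[int],
    can_run_list: List[FakeReq],
    disable_chunked_prefill: bool = True,
) -> bool:
    input_len = len(req.fill_ids)
    if rem_chunk_tokens is None or input_len <= rem_chunk_tokens:
        # Non-chunked prefill: the whole prompt fits into the remaining budget.
        req.extend_input_len = input_len
        can_run_list.append(req)
        return True
    # Chunked prefill branch: the diff's `return False` bails out here, so the
    # request is not truncated to `rem_chunk_tokens` and waits for a later batch.
    if disable_chunked_prefill:
        return False
    req.extend_input_len = rem_chunk_tokens
    req.fill_ids = req.fill_ids[:rem_chunk_tokens]
    can_run_list.append(req)
    return True

# With a 4-token budget, a 10-token prompt is rejected instead of being chunked.
runnable: List[FakeReq] = []
print(try_add_request(FakeReq(fill_ids=list(range(10))), 4, runnable))  # False
```
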
24 changes: 15 additions & 9 deletions python/sglang/srt/managers/tp_worker.py
@@ -214,6 +214,7 @@ def __init__(
        self.new_token_ratio = self.min_new_token_ratio
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.do_not_get_new_batch = False
+        self.new_batch = None

    def exposed_step(self, recv_reqs: List):
        try:
@@ -248,13 +249,12 @@ def exposed_step(self, recv_reqs: List):

    @torch.inference_mode()
    def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
-            new_batch = None
-        else:
-            new_batch = self.get_new_prefill_batch()
-        self.do_not_get_new_batch = False
+        if self.new_batch is None and self.running_batch is None:
+            self.new_batch = self.get_new_prefill_batch()

-        if new_batch is not None:
+        if self.new_batch is not None:
+            new_batch = self.new_batch
+            self.new_batch = None
            # Run a new prefill batch
            self.forward_prefill_batch(new_batch)
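
With this change, `forward_step` no longer consults `do_not_get_new_batch`; it consumes a prefill batch that may already have been prepared during the previous step (`self.new_batch`) and only builds one on the spot when nothing is cached and no decode batch is running. A minimal sketch of that consume-or-build pattern, using made-up `TinyScheduler` names rather than the real `ModelTpServer`/`ScheduleBatch` classes:

```python
from typing import List, Optional

class TinyScheduler:
    def __init__(self) -> None:
        self.new_batch: Optional[List[str]] = None     # batch prepared ahead of time
        self.running_batch: Optional[List[str]] = None  # decode batch, if any
        self.waiting: List[List[str]] = [["a"] * 3, ["b"] * 5]

    def get_new_prefill_batch(self) -> Optional[List[str]]:
        # Stands in for the real scheduling work (queue filtering, prefix matching).
        return self.waiting.pop(0) if self.waiting else None

    def forward_step(self) -> None:
        # Build a batch only if nothing was prepared earlier and nothing is decoding.
        if self.new_batch is None and self.running_batch is None:
            self.new_batch = self.get_new_prefill_batch()

        if self.new_batch is not None:
            batch, self.new_batch = self.new_batch, None  # consume the cached batch
            print("prefill step on", len(batch), "tokens")
        elif self.running_batch is not None:
            print("decode step")

s = TinyScheduler()
s.forward_step()  # prefill step on 3 tokens
s.forward_step()  # prefill step on 5 tokens
```
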

@@ -541,12 +541,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
            self.tree_cache,
        )
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
-        return new_batch
-
-    def forward_prefill_batch(self, batch: ScheduleBatch):
-        # Build batch tensors
-        batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend(self.model_config.vocab_size)
+        return new_batch
+
+    def forward_prefill_batch(self, batch: ScheduleBatch):
        decoding_reqs = []
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.prepare_for_decode()
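
`prepare_for_extend` now runs inside `get_new_prefill_batch`, so any batch built ahead of time is already tensor-ready by the time `forward_prefill_batch` receives it. A toy sketch of the same producer-side preparation, with plain dictionaries standing in for `ScheduleBatch` (hypothetical names, assumed behavior):

```python
from typing import Dict, List, Optional

def get_new_prefill_batch(waiting_queue: List[str], vocab_size: int) -> Optional[Dict]:
    if not waiting_queue:
        return None
    new_batch: Dict = {"reqs": waiting_queue[:2], "vocab_size": vocab_size}
    del waiting_queue[:2]
    # Previously the consumer called prepare_for_extend; building the batch fully
    # here lets that work happen while the GPU is still busy with another batch.
    new_batch["prepared"] = True
    return new_batch

def forward_prefill_batch(batch: Dict) -> None:
    assert batch["prepared"], "batch must already be tensor-ready"
    print("running prefill on", len(batch["reqs"]), "requests")

queue = ["req0", "req1", "req2"]
batch = get_new_prefill_batch(queue, vocab_size=32000)
forward_prefill_batch(batch)  # running prefill on 2 requests
```
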
@@ -558,6 +558,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
        # Forward and sample the next tokens
        if batch.extend_num_tokens != 0:
            logits_output = self.model_runner.forward(batch)
+            if self.new_batch is None:
+                self.new_batch = self.get_new_prefill_batch()
            next_token_ids = self.model_runner.sample(logits_output, batch)

            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
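
Calling `get_new_prefill_batch` between `model_runner.forward` and `model_runner.sample` appears to be the point of the change: CUDA kernel launches return before the GPU finishes, so the CPU-side scheduling of the next batch overlaps with the in-flight forward pass, and the later sampling / host copy of token ids is what actually waits for the GPU. A rough sketch of the overlap, assuming PyTorch and a CUDA device are available (`gpu_forward` and `cpu_schedule_next_batch` are stand-ins, not sglang functions):

```python
import time
from typing import List

import torch

def gpu_forward(x: torch.Tensor) -> torch.Tensor:
    # Stands in for model_runner.forward; the matmul is launched asynchronously.
    return x @ x

def cpu_schedule_next_batch() -> List[int]:
    # Stands in for get_new_prefill_batch: pure CPU work.
    return sorted(range(200_000), reverse=True)

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    logits = gpu_forward(x)                  # returns before the GPU finishes
    next_batch = cpu_schedule_next_batch()   # overlaps with the running kernel
    torch.cuda.synchronize()                 # sampling/host copy would block here
    print(f"overlapped step took {time.perf_counter() - t0:.4f}s, "
          f"output {tuple(logits.shape)}, scheduled {len(next_batch)} ids")
```
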
@@ -623,6 +625,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
        else:
            assert batch.extend_num_tokens != 0
            logits_output = self.model_runner.forward(batch)
+            if self.new_batch is None:
+                self.new_batch = self.get_new_prefill_batch()
            embeddings = logits_output.embeddings.tolist()

        # Check finish conditions
@@ -751,6 +755,8 @@ def forward_decode_batch(self, batch: ScheduleBatch):

        # Forward and sample the next tokens
        logits_output = self.model_runner.forward(batch)
+        if self.new_batch is None:
+            self.new_batch = self.get_new_prefill_batch()
        next_token_ids = self.model_runner.sample(logits_output, batch)
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
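
The same two lines in the decode path mean the next prefill batch is scheduled while a decode forward is in flight. A back-of-the-envelope model of why hiding the scheduling cost inside the forward pass helps per-step latency (made-up numbers; assumes scheduling work is needed on every step):

```python
# Per-step cost: scheduling serially adds to every decode step, while overlapping
# it with the forward pass hides it almost entirely.
forward_ms, schedule_ms, steps = 8.0, 3.0, 100

serial = steps * (forward_ms + schedule_ms)
overlapped = steps * max(forward_ms, schedule_ms)
print(f"serial: {serial:.0f} ms, overlapped: {overlapped:.0f} ms")
# serial: 1100 ms, overlapped: 800 ms
```
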