Commit

Merge branch 'main' into catch-syntax-error
merrymercy committed Sep 28, 2024
2 parents e5c8f9b + 63e845d commit 14a76b9
Showing 15 changed files with 497 additions and 45 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/pr-test.yml
@@ -29,6 +29,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -48,6 +49,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -67,6 +69,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -86,6 +89,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[dev]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Run test
@@ -105,6 +109,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Benchmark Single Latency
@@ -136,6 +141,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Benchmark Offline Throughput (w/o RadixAttention)
@@ -167,6 +173,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
      - name: Benchmark Offline Throughput (TP=2)
@@ -198,6 +205,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
          git clone https://github.com/merrymercy/human-eval.git
@@ -221,6 +229,7 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
          pip install transformers==4.44
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
          git clone https://github.com/merrymercy/human-eval.git
2 changes: 1 addition & 1 deletion README.md
@@ -11,7 +11,7 @@

--------------------------------------------------------------------------------

| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) | [**Join Weekly Development Meeting**](https://calendar.app.google/v2Tw3kuHkKYyp8VV7) |
| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) | [**Join Weekly Development Meeting**](https://t.co/4BFjCLnVHq) |

SGLang is a fast serving framework for large language models and vision language models.
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
34 changes: 34 additions & 0 deletions examples/runtime/reward_model.py
@@ -0,0 +1,34 @@
# launch server
# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding

import json

import requests

url = "http://127.0.0.1:30000"

PROMPT = (
    "What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."

json_data = {
    "conv": [
        [
            {"role": "user", "content": PROMPT},
            {"role": "assistant", "content": RESPONSE1},
        ],
        [
            {"role": "user", "content": PROMPT},
            {"role": "assistant", "content": RESPONSE2},
        ],
    ],
}
response = requests.post(
    url + "/judge",
    json=json_data,
).json()

print(response)
print("scores:", [x["embedding"] for x in response])
18 changes: 18 additions & 0 deletions python/sglang/srt/layers/torchao_utils.py
@@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
    """
    # Lazy import to suppress some warnings
    from torchao.quantization import (
        float8_dynamic_activation_float8_weight,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
    from torchao.quantization.observer import PerRow, PerTensor

    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
@@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
        # this requires newer hardware
        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
        quantize_(dummy_linear, float8_weight_only())
    elif "fp8dq" in torchao_config:
        granularity = torchao_config.split("-")[-1]
        GRANULARITY_MAP = {
            "per_row": PerRow(),
            "per_tensor": PerTensor(),
        }
        assert (
            granularity in GRANULARITY_MAP
        ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
        quantize_(
            dummy_linear,
            float8_dynamic_activation_float8_weight(
                granularity=GRANULARITY_MAP[granularity]
            ),
        )

    return dummy_linear.weight


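As a usage sketch (not part of the commit), the new "fp8dq" branch above reads the granularity from the suffix after the last "-"; something like the following would exercise it, assuming torchao is installed, the GPU supports FP8 (CUDA arch >= 89), and an illustrative 4096x4096 bfloat16 weight:

# Hypothetical call into the helper changed above; the weight shape is illustrative.
import torch
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

weight = torch.nn.Parameter(
    torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
    requires_grad=False,
)
# "fp8dq-per_row" -> float8 dynamic activation / float8 weight with PerRow granularity;
# "fp8dq-per_tensor" selects PerTensor instead.
quantized_weight = torchao_quantize_param_data(weight, "fp8dq-per_row")
print(type(quantized_weight))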
53 changes: 49 additions & 4 deletions python/sglang/srt/managers/io_struct.py
@@ -215,12 +215,11 @@ def post_init(self):
            raise ValueError("Either text or input_ids should be provided.")

        if self.text is not None:
            is_single = isinstance(self.text, str)
            self.is_single = isinstance(self.text, str)
        else:
            is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single
            self.is_single = isinstance(self.input_ids[0], int)

        if is_single:
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
Expand Down Expand Up @@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams


@dataclass
class RewardReqInput:
# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
conv: Union[List[List[Dict]], List[Dict]]
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

is_single: bool = True

def post_init(self):
self.is_single = isinstance(self.conv[0], dict)

if self.is_single:
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.sampling_params is None:
self.sampling_params = {}
self.sampling_params["max_new_tokens"] = 1
else:
# support select operation
self.batch_size = len(self.conv)
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.sampling_params is None:
self.sampling_params = [{}] * self.batch_size
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 1


@dataclass
class TokenizedRewardReqInput:
# The request id
rid: str
# The input text
input_text: str
# The input token ids
input_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
# The request id
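For illustration only (not part of the commit), a small sketch of how RewardReqInput.post_init normalizes a batched request, based on the logic added above; the conversation contents are made up, and the import path is taken from this file's location in the diff:

# Hypothetical use of the new dataclass.
from sglang.srt.managers.io_struct import RewardReqInput

req = RewardReqInput(
    conv=[
        [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}],
        [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "5"}],
    ]
)
req.post_init()
print(req.is_single)        # False: conv[0] is a list of messages, not a single dict
print(len(req.rid))         # 2 generated request ids
print(req.sampling_params)  # dummy params, each with max_new_tokens=1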
51 changes: 41 additions & 10 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -46,8 +46,10 @@
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
@@ -142,7 +144,7 @@ def __init__(

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
@@ -163,7 +165,7 @@ async def generate_request(

async def _handle_single_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
@@ -173,7 +175,13 @@ async def _handle_single_request(

rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
if obj.input_ids is None:
if hasattr(obj, "conv"):
# reward model
assert self.tokenizer is not None
conv = obj.conv if not_use_index else obj.conv[index]
input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
assert self.tokenizer is not None
input_ids = self.tokenizer.encode(input_text)
else:
@@ -269,13 +277,21 @@ async def _handle_single_request(
else obj.lora_path
),
)
else: # is embedding
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
else:
assert isinstance(obj, RewardReqInput)
tokenized_obj = TokenizedRewardReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)

# Recv results
@@ -292,7 +308,7 @@ async def _handle_single_request(

async def _handle_batch_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
):
batch_size = obj.batch_size
@@ -329,9 +345,16 @@ async def _handle_batch_request(
rid = obj.rid[index]
if parallel_sample_num == 1:
## select operation
if obj.input_ids is None:
if hasattr(obj, "conv"):
# reward model
conv = obj.conv[i]
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text[i]
input_ids = self.tokenizer.encode(obj.text[i])
input_ids = self.tokenizer.encode(input_text)
else:
input_text = None
input_ids = obj.input_ids[i]
@@ -370,13 +393,21 @@ async def _handle_batch_request(
else obj.lora_path
),
)
else:
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
else:
assert isinstance(obj, RewardReqInput)
tokenized_obj = TokenizedRewardReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)

event = asyncio.Event()
@@ -442,7 +473,7 @@ def _get_sampling_params(self, sampling_params_data: dict):
async def _wait_for_response(
self,
state: ReqState,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
rid: str,
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
@@ -469,7 +500,7 @@ async def _wait_for_response(
),
obj.return_text_in_logprobs,
)
else: # isinstance(obj, EmbeddingReqInput)
else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
out = state.out_list[-1]

out["index"] = response_index
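As a standalone sketch (not part of the commit) of the conv-to-token-ids path used in the reward-model branch above, assuming a HuggingFace tokenizer with a chat template (the model name is only illustrative):

# Mirrors the hasattr(obj, "conv") branch of _handle_single_request above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
conv = [
    {"role": "user", "content": "What does a sigmoid output?"},
    {"role": "assistant", "content": "A value between 0 and 1."},
]
input_text = tokenizer.apply_chat_template(conv, tokenize=False)
input_ids = tokenizer.encode(input_text)
print(len(input_ids))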
9 changes: 6 additions & 3 deletions python/sglang/srt/managers/tp_worker.py
@@ -22,7 +22,7 @@
import pickle
import time
import warnings
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

import torch
import torch.distributed
@@ -41,6 +41,7 @@
    FlushCacheReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
@@ -223,7 +224,9 @@ def exposed_step(self, recv_reqs: List):
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                    self.do_not_get_new_batch = False
                elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                elif isinstance(
                    recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
                ):
                    self.handle_embedding_request(recv_req)
                    self.do_not_get_new_batch = False
                elif isinstance(recv_req, FlushCacheReq):
@@ -407,7 +410,7 @@ def handle_generate_request(

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.tokenizer = self.tokenizer
