diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml index c1684c14bb..5f6734867c 100644 --- a/.github/ISSUE_TEMPLATE/1-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -12,6 +12,7 @@ body: - label: 2. The bug has not been fixed in the latest version. - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback. - label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed. + - label: 5. Please use English, otherwise it will be closed. - type: textarea attributes: label: Describe the bug @@ -31,7 +32,7 @@ body: attributes: label: Environment description: | - Please provide necessary environment information here with `python3 -m sglang.check_env`. + Please provide necessary environment information here with `python3 -m sglang.check_env`. Otherwise the issue will be closed. placeholder: Environment here. validations: required: true diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml index 5ab369f8b0..31bc4a127e 100644 --- a/.github/ISSUE_TEMPLATE/2-feature-request.yml +++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -3,6 +3,12 @@ description: Suggest an idea for this project title: "[Feature] " body: +- type: checkboxes + attributes: + label: Checklist + options: + - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed. + - label: 2. Please use English, otherwise it will be closed. - type: textarea attributes: label: Motivation diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 20f4a10bc5..21f9a21117 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,15 +1,15 @@ -Thank you for your contribution, we really appreciate it. The following instructions will help improve your pull request and make it easier to receive feedback. If there are any items you don't understand, don't worry. Just submit the pull request and ask the maintainers for help. + ## Motivation -Please explain the motivation behind this PR and the goal you aim to achieve with it. + -## Modification +## Modifications -Briefly describe the changes made in this PR. + ## Checklist -1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. -2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. -3. Modify documentation as needed, such as docstrings or example tutorials. +- [ ] Format your code according to the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/en/contributor_guide.md). +- [ ] Add unit tests as outlined in the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/en/contributor_guide.md). +- [ ] Update documentation as needed, including docstrings or example tutorials. 
\ No newline at end of file diff --git a/.github/workflows/accuracy-test.yml b/.github/workflows/accuracy-test.yml new file mode 100644 index 0000000000..6fb102a4c5 --- /dev/null +++ b/.github/workflows/accuracy-test.yml @@ -0,0 +1,43 @@ +name: Accuracy Test + +on: + push: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + pull_request: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + workflow_dispatch: + +concurrency: + group: accuracy-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + accuracy-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + git clone https://github.com/merrymercy/human-eval.git + cd human-eval + pip install -e . + + - name: Evaluate Accuracy + timeout-minutes: 20 + run: | + cd test/srt + python3 test_eval_accuracy_large.py diff --git a/.github/workflows/cancel-pr-workflow.yml b/.github/workflows/cancel-pr-workflow.yml new file mode 100644 index 0000000000..535884ba60 --- /dev/null +++ b/.github/workflows/cancel-pr-workflow.yml @@ -0,0 +1,22 @@ +name: Cancel PR Workflows on Merge + +on: + pull_request_target: + types: + - closed + +permissions: + actions: write + +jobs: + cancel: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.12.1 + with: + workflow_id: all + access_token: ${{ secrets.GITHUB_TOKEN }} + ignore_sha: true + pr_number: ${{ github.event.pull_request.number }} diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 78ac4d9ec7..11c94775c1 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -20,33 +20,37 @@ concurrency: jobs: e2e-test: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: bench + runs-on: 1-gpu-runner steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - - - name: Benchmark Serving Throughput - run: | - cd test/srt - python3 -m unittest test_serving_throughput.TestServingThroughput.test_default - - - name: Benchmark Serving Throughput (w/o RadixAttention) - run: | - cd test/srt - python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_radix_cache - - - name: Benchmark Serving Throughput (w/o FlashInfer) - run: | - cd test/srt - python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_flashinfer - + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Benchmark Serving Throughput + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_throughput.TestServingThroughput.test_default + + - name: Benchmark Serving Latency + timeout-minutes: 10 + run: | + python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct 
--batch-size 1 --input 128 --output 8 + + - name: Benchmark Serving Throughput (w/o RadixAttention) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_radix_cache + + - name: Benchmark Serving Throughput (w/o ChunkedPrefill) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_chunked_prefill diff --git a/.github/workflows/moe-test.yml b/.github/workflows/moe-test.yml new file mode 100644 index 0000000000..4440aa215f --- /dev/null +++ b/.github/workflows/moe-test.yml @@ -0,0 +1,45 @@ +name: MoE Test + +on: + push: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + pull_request: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + workflow_dispatch: + +concurrency: + group: moe-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + moe-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 2-gpu-runner + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Benchmark MoE Serving Throughput + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default + + - name: Benchmark MoE Serving Throughput (w/o RadixAttention) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default_without_radix_cache diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index f9b79dc674..41a565a638 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -18,31 +18,39 @@ concurrency: cancel-in-progress: true jobs: - unit-test: + unit-test-jobs: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: unit - + runs-on: 1-gpu-runner + strategy: + matrix: + test_type: ['backend-0', 'backend-1', 'frontend'] steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - pip install accelerate - pip install sentence_transformers + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[dev]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Run test + timeout-minutes: 20 + run: | + if [ "${{ matrix.test_type }}" = "frontend" ]; then + cd test/lang + python3 run_suite.py --suite minimal + elif [ "${{ matrix.test_type }}" = "backend-0" ]; then + cd test/srt + python3 run_suite.py --suite minimal --range-begin 0 --range-end 8 + elif [ "${{ matrix.test_type }}" = "backend-1" ]; then + cd test/srt + python3 run_suite.py --suite minimal --range-begin 8 + fi - - name: Test Backend Runtime - run: | - cd test/srt - python3 run_suite.py --suite minimal - - - name: Test Frontend Language - run: | - cd test/lang - python3 run_suite.py --suite minimal + unit-test: + needs: unit-test-jobs + runs-on: ubuntu-latest + steps: + - name: Merge step + run: echo "This 
is an empty merge step" \ No newline at end of file diff --git a/README.md b/README.md index f81593ef6d..8e3e47c100 100644 --- a/README.md +++ b/README.md @@ -17,17 +17,18 @@ SGLang is a fast serving framework for large language models and vision language It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. The core features include: -- **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, flashinfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). +- **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). - **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. ## News - [2024/07] 🔥 Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). -- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). +- [2024/08] 🔥 LLaVA-OneVision with single-image, multi-image and video are supported ([blog](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)). - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
More +- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)). - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)). @@ -55,7 +56,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ### Method 2: From source ``` # Use the last release branch -git clone -b v0.2.11 https://github.com/sgl-project/sglang.git +git clone -b v0.2.14.post2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -76,11 +77,65 @@ docker run --gpus all \ --env "HF_TOKEN=" \ --ipc=host \ lmsysorg/sglang:latest \ - python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --host 0.0.0.0 --port 30000 + python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 ``` +### Method 4: Using docker compose + +
+More + +> This method is recommended if you plan to serve it as a service. +> A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml). + +1. Copy the [compose.yaml](./docker/compose.yaml) to your local machine. +2. Execute the command `docker compose up -d` in your terminal. +
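The compose file publishes the server on port 30000 of the host, so you can sanity-check the deployment with a plain HTTP client once the container reports healthy. The snippet below is a minimal sketch rather than part of the compose setup: it assumes the default host/port from `compose.yaml`, and it uses the `/health` endpoint polled by the compose healthcheck plus the native `/generate` endpoint described in `docs/en/sampling_params.md`.

```python
# Minimal smoke test for the docker compose deployment (assumes localhost:30000).
import requests

base_url = "http://localhost:30000"

# The compose healthcheck polls this endpoint; HTTP 200 means the model is loaded.
print("health:", requests.get(f"{base_url}/health").status_code)

# One generation request against the native /generate endpoint.
resp = requests.post(
    f"{base_url}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(resp.json())
```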
+ +### Method 5: Run on Kubernetes or Clouds with SkyPilot + +
+More + +To deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot). + +1. Install SkyPilot and set up a Kubernetes cluster or cloud access: see [SkyPilot's documentation](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html). +2. Deploy on your own infra with a single command and get the HTTP API endpoint: +
+SkyPilot YAML: sglang.yaml + +```yaml +# sglang.yaml +envs: + HF_TOKEN: null + +resources: + image_id: docker:lmsysorg/sglang:latest + accelerators: A100 + ports: 30000 + +run: | + conda deactivate + python3 -m sglang.launch_server \ + --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \ + --host 0.0.0.0 \ + --port 30000 +``` +
+ +```bash +# Deploy on any cloud or Kubernetes cluster. Use --cloud to select a specific cloud provider. +HF_TOKEN= sky launch -c sglang --env HF_TOKEN sglang.yaml + +# Get the HTTP API endpoint +sky status --endpoint 30000 sglang +``` +3. To further scale up your deployment with autoscaling and failure recovery, check out the [SkyServe + SGLang guide](https://github.com/skypilot-org/skypilot/tree/master/llm/sglang#serving-llama-2-with-sglang-for-more-traffic-using-skyserve). +
+ + ### Common Notes -- If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. ## Backend: SGLang Runtime (SRT) @@ -134,6 +189,13 @@ response = client.chat.completions.create( max_tokens=64, ) print(response) + +# Text embedding +response = client.embeddings.create( + model="default", + input="How are you today", +) +print(response) ``` It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/). @@ -154,7 +216,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance. - If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size. ``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 2048 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096 ``` - Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port. ``` @@ -170,19 +232,21 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct ### Supported Models +**Generative Models** + - Llama / Llama 2 / Llama 3 / Llama 3.1 - Mistral / Mixtral / Mistral NeMo - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE - DeepSeek / DeepSeek 2 -- LLaVA 1.5 / 1.6 - - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 30000` -- LLaVA-NeXT-Video - - see [examples/usage/llava_video](examples/usage/llava_video) +- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava` + - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). 
See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) +- LLaVA 1.5 / 1.6 / NeXT + - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3` + - `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8 --chat-template=chatml-llava` + - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) - Yi-VL - - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). - StableLM - Command-R - DBRX @@ -190,37 +254,52 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - ChatGLM - InternLM 2 +**Embedding Models** + +- e5-mistral +- gte-Qwen2 + - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding` + Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md). #### Use Models From ModelScope -To use model from [ModelScope](https://www.modelscope.cn), setting environment variable SGLANG_USE_MODELSCOPE. +
+More + +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE. ``` export SGLANG_USE_MODELSCOPE=true ``` Launch [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) Server ``` SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 -``` +``` + +
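The same environment variable also applies when you embed the runtime in your own process instead of launching a standalone server. The sketch below is an assumption-laden example rather than an official recipe: it presumes `SGLANG_USE_MODELSCOPE` is honored at model-loading time (as for `launch_server` above) and uses an arbitrary prompt.

```python
# Load qwen/Qwen2-7B-Instruct from ModelScope and drive it with the frontend language.
import os

os.environ["SGLANG_USE_MODELSCOPE"] = "true"  # set before the runtime starts

import sglang as sgl


@sgl.function
def quick_question(s, question):
    s += question
    s += sgl.gen("answer", max_tokens=64, temperature=0)


if __name__ == "__main__":
    runtime = sgl.Runtime(model_path="qwen/Qwen2-7B-Instruct")
    sgl.set_default_backend(runtime)

    state = quick_question.run(question="What is the capital of China? ")
    print(state["answer"])

    runtime.shutdown()
```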
#### Run Llama 3.1 405B +
+More ```bash -## Run 405B (fp8) on a single node +# Run 405B (fp8) on a single node python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 -## Run 405B (fp16) on two nodes -# replace the `172.16.4.52:20000` with your own first node ip address and port, disable CUDA Graph temporarily - -# on the first node -GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph --mem-frac 0.75 +# Run 405B (fp16) on two nodes +## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port +GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph -# on the second -GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph --mem-frac 0.75 +## on the second node, replace the `172.16.4.52:20000` with your own first node ip address and port +GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph ```
+ ### Benchmark Performance -- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`. +- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. + Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. + A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, please use `sglang.bench_serving` instead. ``` python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32 ``` @@ -353,7 +432,7 @@ def tip_suggestion(s): s += "In summary" + sgl.gen("summary") ``` -#### Multi Modality +#### Multi-Modality Use `sgl.image` to pass an image as input. ```python @@ -407,7 +486,7 @@ def character_gen(s, name): s += sgl.gen("json_output", max_tokens=256, regex=character_regex) ``` -See also [json_decode.py](examples/usage/json_decode.py) for an additional example on specifying formats with Pydantic models. +See also [json_decode.py](examples/usage/json_decode.py) for an additional example of specifying formats with Pydantic models. #### Batching Use `run_batch` to run a batch of requests with continuous batching. @@ -469,7 +548,6 @@ def chat_example(s): - The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. - The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. 
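To make the two implementation notes above concrete, here is a small sketch that combines both arguments in one program; the choice list, the integer regex, and the local endpoint are illustrative assumptions, not fixed pieces of the README examples.

```python
# `choices` picks the option with the highest token-length normalized log probability;
# `regex` constrains decoding through logit-bias masking.
import sglang as sgl


@sgl.function
def pick_tool(s, question):
    s += "Question: " + question + "\n"
    s += "Tool to use: " + sgl.gen("tool", choices=["calculator", "search engine"]) + "\n"
    s += "Answer: " + sgl.gen("answer", max_tokens=8, regex=r"[0-9]+")


if __name__ == "__main__":
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    state = pick_tool.run(question="What is 7 * 6?")
    print(state["tool"], state["answer"])
```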
- ## Benchmark And Performance ![8b_throughput](https://lmsys.org/images/blog/sglang_llama3/8b_throughput.svg) ![70b_fp8_throughput](https://lmsys.org/images/blog/sglang_llama3/70b_fp8_throughput.svg) diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py index c80c17a249..2a938d6bb9 100644 --- a/benchmark/gsm8k/bench_other.py +++ b/benchmark/gsm8k/bench_other.py @@ -65,10 +65,9 @@ def main(args): def get_one_answer(i): answer = call_generate( prompt=few_shot_examples + questions[i], - # prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i], temperature=0, max_tokens=256, - stop="Question", + stop=["Question", "Assistant:", "<|separator|>"], ) states[i] = answer diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index 298ec11d73..d32790fe0c 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -64,7 +64,9 @@ def main(args): @sgl.function def few_shot_gsm8k(s, question): s += few_shot_examples + question - s += sgl.gen("answer", max_tokens=512, stop="Question") + s += sgl.gen( + "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"] + ) ##################################### ########## SGL Program End ########## @@ -88,6 +90,9 @@ def few_shot_gsm8k(s, question): for i in range(len(states)): preds.append(get_answer_value(states[i]["answer"])) + # print(f"{preds=}") + # print(f"{labels=}") + # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) diff --git a/benchmark/json_decode_regex/bench_other.py b/benchmark/json_decode_regex/bench_other.py index bbe22835a3..d80ea1de7e 100644 --- a/benchmark/json_decode_regex/bench_other.py +++ b/benchmark/json_decode_regex/bench_other.py @@ -6,11 +6,11 @@ from tqdm import tqdm -from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING +from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl -REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" +REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]" # fmt: off @@ -20,9 +20,9 @@ def json_decode(document, generate): s += "Here is the name, country, and symbol of the city in JSON format.\n" s += "{\n" s += ' "name": ' - s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "country": ' - s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "latitude": ' s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "population": ' diff --git a/benchmark/json_decode_regex/bench_sglang.py b/benchmark/json_decode_regex/bench_sglang.py index 1964387229..462c77750c 100644 --- a/benchmark/json_decode_regex/bench_sglang.py +++ b/benchmark/json_decode_regex/bench_sglang.py @@ -3,14 +3,14 @@ import time import sglang as sgl -from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING +from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) from sglang.utils import dump_state_text, read_jsonl -REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" +REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]" # fmt: off @sgl.function @@ -18,8 +18,8 @@ def json_warm_up(s): s += "The information 
about Hogwarts is in the following JSON format.\n" with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n" - s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n" + s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n" s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n" @@ -35,8 +35,8 @@ def json_decode(s, document): s += "Here is the name, country, and symbol of the city in JSON format.\n" with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n" - s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n" + s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n" s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n" diff --git a/benchmark/latency_throughput/README.md b/benchmark/latency_throughput/README.md deleted file mode 100644 index b6c2e67971..0000000000 --- a/benchmark/latency_throughput/README.md +++ /dev/null @@ -1,71 +0,0 @@ -# Benchmark Latency and Throughput - -## SGLang - -### Launch a server -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 -``` - -### Benchmark one batch - -``` -python3 bench_one.py -python3 bench_one.py --batch-size 64 -``` - -### Benchmark online serving with many requests - -``` -python3 bench_serving.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 -``` - -### Benchmark online serving on the ShareGPT dataset - -#### Download data -``` -wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -``` - -#### Run ShareGPT -``` -python3 bench_serving.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 -``` - -### Profile with Nsight -1. To profile a single batch, use `nsys profile --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512` -2. To profile a server, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B`. 
- - -## Other baselines - -### vLLM -``` -python3 -m vllm.entrypoints.api_server --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel 1 --disable-log-requests --swap-space 16 --port 21000 -``` - -``` -# run synthetic -python3 bench_serving.py --backend vllm --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 -``` - -``` -# run ShareGPT -python3 bench_serving.py --backend vllm --port 21000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 -``` - -``` -# run one batch -python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B --tensor 8 --disable-log-requests --max-num-seqs 1024 --quantization fp8 - -python3 bench_one.py --input-len 1024 --batch-size 1 1 2 4 8 16 32 64 128 256 512 768 1024 --port 8000 --backend vllm -``` - -### LightLLM -``` -python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat-hf --max_total_token_num 15600 --tokenizer_mode auto --port 22000 -``` - -``` -python3 bench_serving.py --backend lightllm --port 22000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 -``` diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py deleted file mode 100644 index b390c44a53..0000000000 --- a/benchmark/latency_throughput/bench_one.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Usage: -python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512 -""" - -import argparse -import json -import time - -import numpy as np -import requests - - -def run_one_batch_size(bs): - url = f"{args.host}:{args.port}" - max_new_tokens = args.max_tokens - - if args.input_len: - input_ids = [ - [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] - for _ in range(bs) - ] - else: - text = [f"{i, }" for i in range(bs)] - - tic = time.time() - if args.backend == "srt": - if args.input_len: - inputs = {"input_ids": input_ids} - else: - inputs = {"text": text} - - response = requests.post( - url + "/generate", - json={ - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - "ignore_eos": True, - }, - **inputs, - }, - ) - elif args.backend == "lightllm": - response = requests.post( - url + "/generate", - json={ - "inputs": text[0], - "parameters": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - "ignore_eos": True, - }, - }, - ) - elif args.backend == "vllm": - if args.input_len: - inputs = {"prompt": input_ids} - else: - inputs = {"prompt": text} - - response = requests.post( - url + "/v1/completions", - json={ - "model": args.vllm_model_name, - "temperature": 0, - "max_tokens": max_new_tokens, - "ignore_eos": True, - **inputs, - }, - ) - elif args.backend == "ginfer": - import grpc - from ginfer import sampler_pb2, sampler_pb2_grpc - - sampler_channel = grpc.insecure_channel(url.replace("http://", "")) - sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) - - tic = time.time() - sample_request = sampler_pb2.SampleTextRequest( - prompt=text[0], - settings=sampler_pb2.SampleSettings( - max_len=max_new_tokens, - rng_seed=0, - temperature=0, - nucleus_p=1, - ), - ) - stream = sampler.SampleText(sample_request) - response = "".join([x.text for x in stream]) - latency = time.time() - tic - - if isinstance(response, str): - ret = response - else: - ret = response.json() - print(ret) - - input_len = args.input_len if 
args.input_len else 1 - output_len = max_new_tokens - - output_throughput = bs * max_new_tokens / latency - overall_throughput = bs * (input_len + output_len) / latency - print(f"latency: {latency:.2f} s") - print(f"output throughput: {output_throughput:.2f} token/s") - print(f"(input + output) throughput: {overall_throughput:.2f} token/s") - - with open("results.jsonl", "a") as fout: - res = { - "backend": args.backend, - "input_len": args.input_len, - "output_len": args.max_tokens, - "batch_size": bs, - "latency": latency, - "output_throughput": output_throughput, - "overall_throughput": overall_throughput, - } - fout.write(json.dumps(res) + "\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="http://127.0.0.1") - parser.add_argument("--port", type=int, default=None) - parser.add_argument("--backend", type=str, default="srt") - parser.add_argument("--input-len", type=int, default=None) - parser.add_argument("--batch-size", type=int, nargs="*", default=[1]) - parser.add_argument("--max-tokens", type=int, default=256) - parser.add_argument( - "--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B" - ) - args = parser.parse_args() - - if args.port is None: - if args.backend == "srt": - args.port = 30000 - elif args.backend == "vllm": - args.port = 21000 - elif args.backend == "lightllm": - args.port = 22000 - elif args.backend == "ginfer": - args.port = 9988 - else: - raise ValueError(f"Invalid backend: {args.backend}") - - for bs in args.batch_size: - run_one_batch_size(bs) diff --git a/benchmark/latency_throughput/bench_serving.py b/benchmark/latency_throughput/bench_serving.py deleted file mode 100644 index 74fafc9494..0000000000 --- a/benchmark/latency_throughput/bench_serving.py +++ /dev/null @@ -1,374 +0,0 @@ -"""Benchmark online serving throughput. - -On the server side, run one of the following commands: - (vLLM backend) - python -m vllm.entrypoints.api_server \ - --model --swap-space 16 \ - --disable-log-requests - - (TGI backend) - ./launch_hf_server.sh - -On the client side, run: - python benchmarks/benchmark_serving.py \ - --backend \ - --tokenizer --dataset \ - --request-rate -""" - -import argparse -import asyncio -import json -import os -import random -import time -from typing import AsyncGenerator, List, Tuple - -import aiohttp -import numpy as np -from tqdm.asyncio import tqdm_asyncio -from transformers import AutoTokenizer - -# (prompt len, output len, latency) -REQUEST_LATENCY: List[Tuple[int, int, float]] = [] - - -def sample_requests( - dataset_path: str, - num_requests: int, - tokenizer: AutoTokenizer, -) -> List[Tuple[str, int, int]]: - def load_dataset(): - with open(dataset_path, encoding="utf-8") as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - ] - - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) - - # Filter out too long sequences. 
- filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - # This is because TGI causes errors when the input or output length - # is too short. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - return filtered_dataset - - try: - from diskcache import Cache - - home_dir = os.path.expanduser("~") - cache = Cache(f"{home_dir}/.cache/sglang") - with Cache(cache.directory) as reference: - reference_key = f"{dataset_path}_{tokenizer.name_or_path}" - if reference_key in reference: - print("Reading dataset from cache...") - dataset = reference[reference_key] - else: - dataset = load_dataset() - reference[reference_key] = dataset - except ImportError: - dataset = load_dataset() - - # Sample the requests. - sampled_requests = random.sample(dataset, num_requests) - return sampled_requests - - -async def get_request( - input_requests: List[Tuple[str, int, int]], - request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: - input_requests = iter(input_requests) - for request in input_requests: - yield request - - if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. - await asyncio.sleep(interval) - - -async def send_request( - backend: str, - api_url: str, - prompt: str, - prompt_len: int, - output_len: int, - best_of: int, - use_beam_search: bool, -) -> None: - request_start_time = time.perf_counter() - - headers = {"User-Agent": "Benchmark Client"} - if backend == "vllm": - pload = { - "prompt": prompt, - "n": 1, - "best_of": best_of, - "use_beam_search": use_beam_search, - "temperature": 0.0 if use_beam_search else 1.0, - "top_p": 1.0, - "max_tokens": output_len, - "ignore_eos": True, - "stream": False, - } - elif backend == "tgi": - assert not use_beam_search - params = { - "best_of": best_of, - "max_new_tokens": output_len, - "do_sample": True, - } - pload = { - "inputs": prompt, - "parameters": params, - } - elif backend == "srt": - assert not use_beam_search - params = { - "ignore_eos": True, - "max_new_tokens": output_len, - } - pload = { - "text": prompt, - "sampling_params": params, - } - elif backend == "lightllm": - assert not use_beam_search - params = { - "ignore_eos": True, - "max_new_tokens": output_len, - } - pload = { - "inputs": prompt, - "parameters": params, - } - elif backend == "ginfer": - pass - else: - raise ValueError(f"Unknown backend: {backend}") - - if backend != "ginfer": - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout) as session: - while True: - async with session.post( - api_url, headers=headers, json=pload - ) as response: - chunks = [] - async for chunk, _ in response.content.iter_chunks(): - chunks.append(chunk) - output = b"".join(chunks).decode("utf-8") - output = json.loads(output) - - # Re-send the request if it failed. 
- if "error" not in output: - break - else: - print(output) - else: - import grpc - from ginfer import sampler_pb2, sampler_pb2_grpc - - api_url = api_url.replace("http://", "").replace("/generate", "") - sampler_channel = grpc.aio.insecure_channel(api_url) - sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) - - request_end_time = time.perf_counter() - sample_request = sampler_pb2.SampleTextRequest( - prompt=prompt, - settings=sampler_pb2.SampleSettings( - max_len=output_len, - rng_seed=0, - temperature=0, - nucleus_p=1, - ), - ) - stream = sampler.SampleText(sample_request) - response = "".join([x.text async for x in stream]) - - request_end_time = time.perf_counter() - request_latency = request_end_time - request_start_time - REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) - - -async def benchmark( - backend: str, - api_url: str, - input_requests: List[Tuple[str, int, int]], - best_of: int, - use_beam_search: bool, - request_rate: float, -) -> None: - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request - task = asyncio.create_task( - send_request( - backend, - api_url, - prompt, - prompt_len, - output_len, - best_of, - use_beam_search, - ) - ) - tasks.append(task) - await tqdm_asyncio.gather(*tasks) - - -def main(args: argparse.Namespace): - print(args) - random.seed(args.seed) - np.random.seed(args.seed) - - api_url = f"{args.host}:{args.port}/generate" - if args.tokenizer.endswith(".json") or args.tokenizer.endswith(".model"): - from sglang.srt.hf_transformers_utils import get_tokenizer - - tokenizer = get_tokenizer(args.tokenizer) - else: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code - ) - - if args.dataset: - input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) - else: - input_lens = np.random.randint( - int(args.input_len * args.range_ratio), - args.input_len + 1, - size=args.num_prompts, - ) - output_lens = np.random.randint( - int(args.output_len * args.range_ratio), - args.output_len + 1, - size=args.num_prompts, - ) - offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts) - input_requests = [] - for i in range(args.num_prompts): - prompt = tokenizer.decode( - [ - (offsets[i] + i + j) % (tokenizer.vocab_size - 129) + 128 - for j in range(input_lens[i]) - ] - ) - input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) - - benchmark_start_time = time.perf_counter() - asyncio.run( - benchmark( - args.backend, - api_url, - input_requests, - args.best_of, - args.use_beam_search, - args.request_rate, - ) - ) - benchmark_end_time = time.perf_counter() - benchmark_time = benchmark_end_time - benchmark_start_time - - # Compute the statistics. 
- latencies = [latency for _, _, latency in REQUEST_LATENCY] - avg_latency = np.mean(latencies) - avg_per_token_latency = np.mean( - [ - latency / (prompt_len + output_len) - for prompt_len, output_len, latency in REQUEST_LATENCY - ] - ) - avg_per_output_token_latency = np.mean( - [latency / output_len for _, output_len, latency in REQUEST_LATENCY] - ) - decoding_throughput = ( - np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time - ) - - # latencies = [round(latency, 2) for _, _, latency in REQUEST_LATENCY] - # print(latencies) - - print(f"Total time: {benchmark_time:.2f} s") - print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s") - print(f"Decoding throughput: {decoding_throughput:.2f} token/s") - print(f"Average latency: {avg_latency:.2f} s") - print(f"Average latency per token: {avg_per_token_latency:.2f} s") - print(f"Average latency per output token: {avg_per_output_token_latency:.2f} s") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark the online serving throughput." - ) - parser.add_argument( - "--backend", - type=str, - default="srt", - choices=["vllm", "tgi", "srt", "lightllm", "ginfer"], - ) - parser.add_argument("--host", type=str, default="http://localhost") - parser.add_argument("--port", type=int, default=30000) - parser.add_argument("--dataset", type=str, help="Path to the dataset.") - parser.add_argument("--input-len", type=int, default=2048) - parser.add_argument("--output-len", type=int, default=256) - parser.add_argument("--range-ratio", type=float, default=1.0) - parser.add_argument( - "--tokenizer", - type=str, - default="NousResearch/Meta-Llama-3-8B", - help="Name or path of the tokenizer.", - ) - parser.add_argument( - "--best-of", - type=int, - default=1, - help="Generates `best_of` sequences per prompt and " "returns the best one.", - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-prompts", type=int, default=1000, help="Number of prompts to process." - ) - parser.add_argument( - "--request-rate", - type=float, - default=float("inf"), - help="Number of requests per second. If this is inf, " - "then all the requests are sent at time 0. 
" - "Otherwise, we use Poisson process to synthesize " - "the request arrival times.", - ) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="trust remote code from huggingface", - ) - args = parser.parse_args() - main(args) diff --git a/docker/Dockerfile b/docker/Dockerfile index 95127b33a9..42656ca264 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,5 +1,5 @@ ARG CUDA_VERSION=12.1.1 -FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu20.04 +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 ARG BUILD_TYPE=all ENV DEBIAN_FRONTEND=noninteractive @@ -8,7 +8,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && apt update -y \ && apt install software-properties-common -y \ && add-apt-repository ppa:deadsnakes/ppa -y && apt update \ - && apt install python3.10 -y \ + && apt install python3.10 python3.10-dev -y \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2 \ && update-alternatives --set python3 /usr/bin/python3.10 && apt install python3.10-distutils -y \ && apt install curl git sudo -y \ diff --git a/docker/compose.yaml b/docker/compose.yaml new file mode 100644 index 0000000000..f2da3a416a --- /dev/null +++ b/docker/compose.yaml @@ -0,0 +1,31 @@ +services: + sglang: + image: lmsysorg/sglang:latest + container_name: sglang + volumes: + - ${HOME}/.cache/huggingface:/root/.cache/huggingface + restart: always + network_mode: host + # Or you can only publish port 30000 + # ports: + # - 30000:30000 + environment: + HF_TOKEN: + entrypoint: python3 -m sglang.launch_server + command: + --model-path meta-llama/Meta-Llama-3.1-8B-Instruct + --host 0.0.0.0 + --port 30000 + ulimits: + memlock: -1 + stack: 67108864 + ipc: host + healthcheck: + test: ["CMD-SHELL", "curl -f http://localhost:30000/health || exit 1"] + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0'] + capabilities: [gpu] diff --git a/docker/k8s-sglang-service.yaml b/docker/k8s-sglang-service.yaml new file mode 100644 index 0000000000..c217f356af --- /dev/null +++ b/docker/k8s-sglang-service.yaml @@ -0,0 +1,76 @@ +apiVersion: node.k8s.io/v1 +kind: RuntimeClass +metadata: + name: nvidia +handler: nvidia +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: meta-llama-31-8b-instruct-sglang +spec: + replicas: 1 + strategy: + type: Recreate + selector: + matchLabels: + app: meta-llama-31-8b-instruct-sglang + template: + metadata: + labels: + app: meta-llama-31-8b-instruct-sglang + model: meta-llama-31-8b-instruct + engine: sglang + spec: + hostIPC: true + restartPolicy: Always + runtimeClassName: nvidia + containers: + - name: meta-llama-31-8b-instruct-sglang + image: docker.io/lmsysorg/sglang:latest + imagePullPolicy: Always # IfNotPresent or Never + ports: + - containerPort: 30000 + command: ["python3", "-m", "sglang.launch_server"] + args: ["--model-path", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--host", "0.0.0.0", "--port", "30000"] + env: + - name: HF_TOKEN + value: + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: hf-cache + mountPath: /root/.cache/huggingface + readOnly: true + - name: localtime + mountPath: /etc/localtime + readOnly: true + livenessProbe: + httpGet: + path: /health + port: 30000 + initialDelaySeconds: 30 + periodSeconds: 10 + volumes: + - name: hf-cache + hostPath: + path: /root/.cache/huggingface + type: Directory + - name: 
localtime + hostPath: + path: /etc/localtime + type: File +--- +apiVersion: v1 +kind: Service +metadata: + name: meta-llama-31-8b-instruct-sglang +spec: + selector: + app: meta-llama-31-8b-instruct-sglang + ports: + - protocol: TCP + port: 30000 # port on host + targetPort: 30000 # port in container + type: LoadBalancer diff --git a/docs/en/benchmark_and_profiling.md b/docs/en/benchmark_and_profiling.md new file mode 100644 index 0000000000..3fbd935891 --- /dev/null +++ b/docs/en/benchmark_and_profiling.md @@ -0,0 +1,49 @@ +# Benchmark and Profiling + +## Benchmark +- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`. + ``` + python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32 + ``` +- Benchmark online serving. Launch a server first and run the following command. + ``` + python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + ``` + +## Profile with Nsight +0. Prerequisite +```bash +# install nsys +# https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html +apt update +apt install -y --no-install-recommends gnupg +echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu$(source /etc/lsb-release; echo "$DISTRIB_RELEASE" | tr -d .)/$(dpkg --print-architecture) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list +apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub +apt update +apt install nsight-systems-cli +``` + +1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512` + +2. To profile a server, e.g. + +```bash +# server +# set the delay and duration times according to needs +nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out --delay 60 --duration 70 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --disable-radix-cache + +# client +python3 -m sglang.bench_serving --backend sglang --num-prompts 6000 --dataset-name random --random-input 4096 --random-output 2048 +``` + +3. Use NVTX, e.g. + +```bash +# install nvtx +pip install nvtx + +# code snippets +import nvtx +with nvtx.annotate("description", color="color"): + # some critical code +``` \ No newline at end of file diff --git a/docs/en/hyperparameter_tuning.md b/docs/en/hyperparameter_tuning.md index 53b92435c7..f2bf9d55f3 100644 --- a/docs/en/hyperparameter_tuning.md +++ b/docs/en/hyperparameter_tuning.md @@ -6,11 +6,12 @@ Achieving a large batch size is the most important thing for attaining high thro When the server is running at full load, look for the following in the log: -```[gpu=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` +```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` ### Tune Your Request Submission Speed `#queue-req` indicates the number of requests in the queue. 
If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed. -A healthy range for `#queue-req` is `100 - 1000`. +A healthy range for `#queue-req` is `50 - 1000`. +On the other hand, do not make `#queue-req` too large because it will also increase the scheduling overhead on the server. ### Tune `--schedule-conservativeness` `token usage` indicates the KV cache memory utilization of the server. `token usage > 0.9` means good utilization. @@ -19,13 +20,14 @@ The case of serving being too conservative can happen when users send many reque On the other hand, if you see `token usage` very high and you frequently see warnings like `decode out of memory happened, #retracted_reqs: 1, #new_token_ratio: 0.9998 -> 1.0000`, you can increase `--schedule-conservativeness` to a value like 1.3. +If you see `decode out of memory happened` occasionally but not frequently, it is okay. ### Tune `--dp-size` and `--tp-size` Data parallelism is better for throughput. When there is enough GPU memory, always favor data parallelism for throughput. -### (Minor) Tune `--max-prefill-tokens`, `--mem-fraction-static`, `--max-running-requests` +### Avoid out-of-memory by tuning `--chunked-prefill-size`, `--mem-fraction-static`, `--max-running-requests` If you see out of memory (OOM) errors, you can decrease these parameters. -If OOM happens during prefill, try to decrease `--max-prefill-tokens`. +If OOM happens during prefill, try to decrease `--chunked-prefill-size` to `4096` or `2048`. If OOM happens during decoding, try to decrease `--max-running-requests`. You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding. diff --git a/docs/en/model_support.md b/docs/en/model_support.md index e46e99e85c..1d720acf5c 100644 --- a/docs/en/model_support.md +++ b/docs/en/model_support.md @@ -5,7 +5,7 @@ To support a new model in SGLang, you only need to add a single file under [SGLa Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang has reused vLLM for most parts of the model implementations. This similarity makes it easy to port many models from vLLM to SGLang. To port a model from vLLM to SGLang, you can compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. Specifically, - - Replace vllm's `Attention` with `RadixAttention`. + - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`. - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. - Remove `Sample`. - Change `forward()` functions, and add `input_metadata`. diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md index 5f1cdece6a..0e1c13e4bd 100644 --- a/docs/en/sampling_params.md +++ b/docs/en/sampling_params.md @@ -1,5 +1,8 @@ # Sampling Parameters in SGLang Runtime This doc describes the sampling parameters of the SGLang Runtime. +It is the low-level endpoint of the runtime. 
+If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API +](https://github.com/sgl-project/sglang?tab=readme-ov-file#openai-compatible-api). The `/generate` endpoint accepts the following arguments in the JSON format. @@ -45,6 +48,8 @@ temperature: float = 1.0, top_p: float = 1.0, # Top-k sampling top_k: int = -1, +# Min-p sampling +min_p: float = 0.0, # Whether to ignore EOS token. ignore_eos: bool = False, # Whether to skip the special tokens during detokenization. @@ -55,6 +60,9 @@ spaces_between_special_tokens: bool = True, regex: Optional[str] = None, # Do parallel sampling and return `n` outputs. n: int = 1, +# Constrains the output to follow a given JSON schema. +# `regex` and `json_schema` cannot be set at the same time. +json_schema: Optional[str] = None, ## Penalties. See [Performance Implications on Penalties] section below for more informations. @@ -138,7 +146,7 @@ print("") Launch a server ``` -python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000 +python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --chat-template chatml-llava ``` Download an image @@ -153,7 +161,9 @@ import requests response = requests.post( "http://localhost:30000/generate", json={ - "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", + "text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n\nDescribe this image in a very short sentence.<|im_end|>\n" + "<|im_start|>assistant\n", "image_data": "example_image.png", "sampling_params": { "temperature": 0, diff --git a/docs/en/setup_github_runner.md b/docs/en/setup_github_runner.md new file mode 100644 index 0000000000..8e817dcc88 --- /dev/null +++ b/docs/en/setup_github_runner.md @@ -0,0 +1,44 @@ +# Set Up Self-hosted Runners for GitHub Action + +## Add a Runner + +### Step 1: Start a docker container. + +You can mount a folder for the shared huggingface model weights cache. The command below uses `/tmp/huggingface` as an example. + +``` +docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 +docker run --shm-size 64g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash +``` + +### Step 2: Configure the runner by `config.sh` + +Run these commands inside the container. + +``` +apt update && apt install -y curl python3-pip git +export RUNNER_ALLOW_RUNASROOT=1 +``` + +Then follow https://github.com/sgl-project/sglang/settings/actions/runners/new?arch=x64&os=linux to run `config.sh` + +**Notes** +- Do not need to specify the runner group +- Give it a name (e.g., `test-sgl-gpu-0`) and some labels (e.g., `1-gpu-runner`). The labels can be editted later in Github Settings. +- Do not need to change the work folder. 
+ +### Step 3: Run the runner by `run.sh` + +- Set up environment variables +``` +export HF_HOME=/hf_home +export SGLANG_IS_IN_CI=true +export HF_TOKEN=hf_xxx +export OPENAI_API_KEY=sk-xxx +export CUDA_VISIBLE_DEVICES=0 +``` + +- Run it forever +``` +while true; do ./run.sh; echo "Restarting..."; sleep 2; done +``` \ No newline at end of file diff --git a/docs/en/setup_runner.md b/docs/en/setup_runner.md deleted file mode 100644 index 34f4576845..0000000000 --- a/docs/en/setup_runner.md +++ /dev/null @@ -1,34 +0,0 @@ -# Set up self hosted runner for GitHub Action - -## Config Runner - -```bash -# https://github.com/sgl-project/sglang/settings/actions/runners/new?arch=x64&os=linux -# Involves some TOKEN and other private information, click the link to view specific steps. -``` - -## Start Runner - -add `/lib/systemd/system/runner.service` -``` -[Unit] -StartLimitIntervalSec=0 -[Service] -Environment="CUDA_VISIBLE_DEVICES=7" -Environment="XDG_CACHE_HOME=/data/.cache" -Environment="HF_TOKEN=hf_**" -Environment="OPENAI_API_KEY=sk-**" -Environment="HOME=/data/zhyncs" -Restart=always -RestartSec=1 -ExecStart=/data/zhyncs/actions-runner/run.sh -[Install] -WantedBy=multi-user.target -``` - -```bash -sudo systemctl daemon-reload -sudo systemctl start runner -sudo systemctl enable runner -sudo systemctl status runner -``` diff --git a/examples/quick_start/anthropic_example_chat.py b/examples/frontend_language/quick_start/anthropic_example_chat.py similarity index 100% rename from examples/quick_start/anthropic_example_chat.py rename to examples/frontend_language/quick_start/anthropic_example_chat.py diff --git a/examples/quick_start/anthropic_example_complete.py b/examples/frontend_language/quick_start/anthropic_example_complete.py similarity index 100% rename from examples/quick_start/anthropic_example_complete.py rename to examples/frontend_language/quick_start/anthropic_example_complete.py diff --git a/examples/quick_start/azure_openai_example_chat.py b/examples/frontend_language/quick_start/azure_openai_example_chat.py similarity index 100% rename from examples/quick_start/azure_openai_example_chat.py rename to examples/frontend_language/quick_start/azure_openai_example_chat.py diff --git a/examples/quick_start/gemini_example_chat.py b/examples/frontend_language/quick_start/gemini_example_chat.py similarity index 100% rename from examples/quick_start/gemini_example_chat.py rename to examples/frontend_language/quick_start/gemini_example_chat.py diff --git a/examples/quick_start/gemini_example_complete.py b/examples/frontend_language/quick_start/gemini_example_complete.py similarity index 100% rename from examples/quick_start/gemini_example_complete.py rename to examples/frontend_language/quick_start/gemini_example_complete.py diff --git a/examples/quick_start/gemini_example_multimodal_chat.py b/examples/frontend_language/quick_start/gemini_example_multimodal_chat.py similarity index 100% rename from examples/quick_start/gemini_example_multimodal_chat.py rename to examples/frontend_language/quick_start/gemini_example_multimodal_chat.py diff --git a/examples/quick_start/images/cat.jpeg b/examples/frontend_language/quick_start/images/cat.jpeg similarity index 100% rename from examples/quick_start/images/cat.jpeg rename to examples/frontend_language/quick_start/images/cat.jpeg diff --git a/examples/quick_start/images/dog.jpeg b/examples/frontend_language/quick_start/images/dog.jpeg similarity index 100% rename from examples/quick_start/images/dog.jpeg rename to 
examples/frontend_language/quick_start/images/dog.jpeg diff --git a/examples/quick_start/srt_example_chat.py b/examples/frontend_language/quick_start/local_example_chat.py similarity index 98% rename from examples/quick_start/srt_example_chat.py rename to examples/frontend_language/quick_start/local_example_chat.py index b1e1658a2a..e1e4b62cca 100644 --- a/examples/quick_start/srt_example_chat.py +++ b/examples/frontend_language/quick_start/local_example_chat.py @@ -1,6 +1,6 @@ """ Usage: -python3 srt_example_chat.py +python3 local_example_chat.py """ import sglang as sgl diff --git a/examples/quick_start/srt_example_complete.py b/examples/frontend_language/quick_start/local_example_complete.py similarity index 97% rename from examples/quick_start/srt_example_complete.py rename to examples/frontend_language/quick_start/local_example_complete.py index 056245979f..00a451cf64 100644 --- a/examples/quick_start/srt_example_complete.py +++ b/examples/frontend_language/quick_start/local_example_complete.py @@ -1,6 +1,6 @@ """ Usage: -python3 srt_example_complete.py +python3 local_example_complete.py """ import sglang as sgl diff --git a/examples/quick_start/srt_example_llava.py b/examples/frontend_language/quick_start/local_example_llava_next.py similarity index 69% rename from examples/quick_start/srt_example_llava.py rename to examples/frontend_language/quick_start/local_example_llava_next.py index 5d8f752394..823dc7b0e8 100644 --- a/examples/quick_start/srt_example_llava.py +++ b/examples/frontend_language/quick_start/local_example_llava_next.py @@ -1,8 +1,14 @@ """ -Usage: python3 srt_example_llava.py +Usage: python3 local_example_llava_next.py """ +from PIL import ImageFile + import sglang as sgl +from sglang.lang.chat_template import get_chat_template +from sglang.srt.utils import load_image + +ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images @sgl.function @@ -44,10 +50,17 @@ def batch(): if __name__ == "__main__": - runtime = sgl.Runtime( - model_path="liuhaotian/llava-v1.6-vicuna-7b", - tokenizer_path="llava-hf/llava-1.5-7b-hf", - ) + import multiprocessing as mp + + mp.set_start_method("spawn", force=True) + + runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b") + runtime.endpoint.chat_template = get_chat_template("llama-3-instruct") + + # Or you can use the 72B model + # runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8) + # runtime.endpoint.chat_template = get_chat_template("chatml-llava") + sgl.set_default_backend(runtime) print(f"chat template: {runtime.endpoint.chat_template.name}") diff --git a/examples/quick_start/openai_example_chat.py b/examples/frontend_language/quick_start/openai_example_chat.py similarity index 100% rename from examples/quick_start/openai_example_chat.py rename to examples/frontend_language/quick_start/openai_example_chat.py diff --git a/examples/quick_start/openai_example_complete.py b/examples/frontend_language/quick_start/openai_example_complete.py similarity index 100% rename from examples/quick_start/openai_example_complete.py rename to examples/frontend_language/quick_start/openai_example_complete.py diff --git a/examples/quick_start/openrouter_example_chat.py b/examples/frontend_language/quick_start/openrouter_example_chat.py similarity index 100% rename from examples/quick_start/openrouter_example_chat.py rename to examples/frontend_language/quick_start/openrouter_example_chat.py diff --git a/examples/quick_start/together_example_chat.py 
b/examples/frontend_language/quick_start/together_example_chat.py similarity index 100% rename from examples/quick_start/together_example_chat.py rename to examples/frontend_language/quick_start/together_example_chat.py diff --git a/examples/quick_start/together_example_complete.py b/examples/frontend_language/quick_start/together_example_complete.py similarity index 100% rename from examples/quick_start/together_example_complete.py rename to examples/frontend_language/quick_start/together_example_complete.py diff --git a/examples/usage/chinese_regex.py b/examples/frontend_language/usage/chinese_regex.py similarity index 100% rename from examples/usage/chinese_regex.py rename to examples/frontend_language/usage/chinese_regex.py diff --git a/examples/usage/choices_logprob.py b/examples/frontend_language/usage/choices_logprob.py similarity index 100% rename from examples/usage/choices_logprob.py rename to examples/frontend_language/usage/choices_logprob.py diff --git a/examples/usage/cot_decoding.py b/examples/frontend_language/usage/cot_decoding.py similarity index 100% rename from examples/usage/cot_decoding.py rename to examples/frontend_language/usage/cot_decoding.py diff --git a/examples/usage/json_decode.py b/examples/frontend_language/usage/json_decode.py similarity index 95% rename from examples/usage/json_decode.py rename to examples/frontend_language/usage/json_decode.py index dc34d3527b..ce8f5ba706 100644 --- a/examples/usage/json_decode.py +++ b/examples/frontend_language/usage/json_decode.py @@ -35,6 +35,9 @@ def character_gen(s, name): name + " is a character in Harry Potter. Please fill in the following information about this character.\n" ) + s += "The constrained regex is:\n" + s += character_regex + "\n" + s += "The JSON output is:\n" s += sgl.gen("json_output", max_tokens=256, regex=character_regex) diff --git a/examples/usage/json_logprobs.py b/examples/frontend_language/usage/json_logprobs.py similarity index 100% rename from examples/usage/json_logprobs.py rename to examples/frontend_language/usage/json_logprobs.py diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py similarity index 92% rename from examples/usage/llava_video/srt_example_llava_v.py rename to examples/frontend_language/usage/llava_video/srt_example_llava_v.py index 27ba862d30..1f2931a5a4 100644 --- a/examples/usage/llava_video/srt_example_llava_v.py +++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py @@ -1,7 +1,8 @@ """ Usage: pip install opencv-python-headless -python3 srt_example_llava.py + +python3 srt_example_llava_v.py """ import argparse @@ -9,6 +10,8 @@ import os import time +import requests + import sglang as sgl @@ -121,6 +124,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= if __name__ == "__main__": + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + + os.makedirs(cache_dir, exist_ok=True) + + response = requests.get(url) + response.raise_for_status() # Raise an exception for bad responses + + with open(file_path, "wb") as f: + f.write(response.content) + + print(f"File downloaded and saved to: {file_path}") # Create the parser parser = argparse.ArgumentParser( description="Run video processing with specified port." 
@@ -148,7 +165,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= parser.add_argument( "--video-dir", type=str, - default="./videos/Q98Z4OTh8RwmDonc.mp4", + default=os.path.expanduser("~/.cache/jobs.mp4"), help="The directory or path for the processed video files.", ) parser.add_argument( @@ -167,13 +184,9 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= # Parse the arguments args = parser.parse_args() - cur_port = args.port - cur_chunk = args.chunk_idx - num_chunks = args.num_chunks - num_frames = args.num_frames if "34b" in args.model_path.lower(): @@ -185,7 +198,6 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= exit() model_overide_args = {} - model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride model_overide_args["architectures"] = ["LlavaVidForCausalLM"] model_overide_args["num_frames"] = args.num_frames @@ -218,7 +230,6 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= print(f"chat template: {runtime.endpoint.chat_template.name}") # Run a single request - # try: print("\n========== single ==========\n") root = args.video_dir if os.path.isfile(root): @@ -240,13 +251,10 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= ) # Calculate the average processing time print(f"Average processing time per video: {average_time:.2f} seconds") runtime.shutdown() - # except Exception as e: - # print(e) - runtime.shutdown() - # # # Run a batch of requests + # # Run a batch of requests # print("\n========== batch ==========\n") # if not os.path.exists(args.save_dir): # os.makedirs(args.save_dir) - # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks) + # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks) # runtime.shutdown() diff --git a/examples/usage/llava_video/srt_example_llava_v.sh b/examples/frontend_language/usage/llava_video/srt_example_llava_v.sh similarity index 100% rename from examples/usage/llava_video/srt_example_llava_v.sh rename to examples/frontend_language/usage/llava_video/srt_example_llava_v.sh diff --git a/examples/usage/openai_chat_speculative.py b/examples/frontend_language/usage/openai_chat_speculative.py similarity index 100% rename from examples/usage/openai_chat_speculative.py rename to examples/frontend_language/usage/openai_chat_speculative.py diff --git a/examples/usage/openai_speculative.py b/examples/frontend_language/usage/openai_speculative.py similarity index 100% rename from examples/usage/openai_speculative.py rename to examples/frontend_language/usage/openai_speculative.py diff --git a/examples/usage/parallel_sample.py b/examples/frontend_language/usage/parallel_sample.py similarity index 100% rename from examples/usage/parallel_sample.py rename to examples/frontend_language/usage/parallel_sample.py diff --git a/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb b/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb similarity index 100% rename from examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb rename to examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb diff --git a/examples/usage/readme_examples.py b/examples/frontend_language/usage/readme_examples.py similarity index 100% rename from examples/usage/readme_examples.py rename to examples/frontend_language/usage/readme_examples.py diff 
--git a/examples/usage/streaming.py b/examples/frontend_language/usage/streaming.py similarity index 100% rename from examples/usage/streaming.py rename to examples/frontend_language/usage/streaming.py diff --git a/examples/usage/triton/Dockerfile b/examples/frontend_language/usage/triton/Dockerfile similarity index 100% rename from examples/usage/triton/Dockerfile rename to examples/frontend_language/usage/triton/Dockerfile diff --git a/examples/usage/triton/README.md b/examples/frontend_language/usage/triton/README.md similarity index 100% rename from examples/usage/triton/README.md rename to examples/frontend_language/usage/triton/README.md diff --git a/examples/usage/triton/models/character_generation/1/model.py b/examples/frontend_language/usage/triton/models/character_generation/1/model.py similarity index 100% rename from examples/usage/triton/models/character_generation/1/model.py rename to examples/frontend_language/usage/triton/models/character_generation/1/model.py diff --git a/examples/usage/triton/models/character_generation/config.pbtxt b/examples/frontend_language/usage/triton/models/character_generation/config.pbtxt similarity index 100% rename from examples/usage/triton/models/character_generation/config.pbtxt rename to examples/frontend_language/usage/triton/models/character_generation/config.pbtxt diff --git a/examples/quick_start/srt_example_yi_vl.py b/examples/quick_start/srt_example_yi_vl.py deleted file mode 100644 index 66c7d57126..0000000000 --- a/examples/quick_start/srt_example_yi_vl.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Usage: python3 srt_example_yi_vl.py - -Requirements: transformers==4.38 -""" - -import sglang as sgl - - -@sgl.function -def image_qa(s, image_path, question): - s += sgl.user(sgl.image(image_path) + question) - s += sgl.assistant(sgl.gen("answer")) - - -def single(): - state = image_qa.run( - image_path="images/cat.jpeg", - question="What is this?", - max_new_tokens=64, - stop="###", - ) - print(state["answer"], "\n") - - -def stream(): - state = image_qa.run( - image_path="images/cat.jpeg", - question="What is this?", - max_new_tokens=64, - stream=True, - stop="###", - ) - - for out in state.text_iter("answer"): - print(out, end="", flush=True) - print() - - -def batch(): - states = image_qa.run_batch( - [ - {"image_path": "images/cat.jpeg", "question": "What is this?"}, - {"image_path": "images/dog.jpeg", "question": "What is this?"}, - ], - max_new_tokens=64, - stop="###", - ) - for s in states: - print(s["answer"], "\n") - - -if __name__ == "__main__": - runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-6B") - # runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-34B") - sgl.set_default_backend(runtime) - - # Run a single request - print("\n========== single ==========\n") - single() - - # Stream output - print("\n========== stream ==========\n") - stream() - - # Run a batch of requests - print("\n========== batch ==========\n") - batch() - - runtime.shutdown() diff --git a/examples/usage/async_io.py b/examples/runtime/async_io_api.py similarity index 100% rename from examples/usage/async_io.py rename to examples/runtime/async_io_api.py diff --git a/examples/usage/llava/http_llama3_llava_test.py b/examples/runtime/llava_onevision/http_llama3_llava_test.py similarity index 94% rename from examples/usage/llava/http_llama3_llava_test.py rename to examples/runtime/llava_onevision/http_llama3_llava_test.py index 813a26af53..a019e214d6 100644 --- a/examples/usage/llava/http_llama3_llava_test.py +++ 
b/examples/runtime/llava_onevision/http_llama3_llava_test.py @@ -4,7 +4,7 @@ # Installing latest sglang. # Endpoint Service CLI: -# python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --tokenizer-path lmms-lab/llama3-llava-next-8b-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4 +python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 python3 http_llama3_llava_test.py @@ -16,7 +16,6 @@ import asyncio import copy import json -import time import aiohttp import requests diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py new file mode 100644 index 0000000000..0c93d2ce2b --- /dev/null +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -0,0 +1,264 @@ +""" +Usage: + +python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava + +python3 http_llava_onevision_test.py +""" + +import base64 +import io +import os +import sys +import time + +import numpy as np +import openai +import requests +from decord import VideoReader, cpu +from PIL import Image + +# pip install httpx==0.23.3 +# pip install decord +# pip install protobuf==3.20.0 + + +def download_video(url, cache_dir): + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + + response = requests.get(url) + response.raise_for_status() + + with open(file_path, "wb") as f: + f.write(response.content) + + print(f"File downloaded and saved to: {file_path}") + return file_path + + +def create_openai_client(base_url): + return openai.Client(api_key="EMPTY", base_url=base_url) + + +def image_stream_request_test(client): + print("----------------------Image Stream Request Test----------------------") + stream_request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "Please describe this image. Please list the benchmarks and the models.", + }, + ], + }, + ], + temperature=0.7, + max_tokens=1024, + stream=True, + ) + stream_response = "" + + for chunk in stream_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + stream_response += content + sys.stdout.write(content) + sys.stdout.flush() + + print("-" * 30) + + +def multi_image_stream_request_test(client): + print( + "----------------------Multi-Images Stream Request Test----------------------" + ) + stream_request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + }, + }, + { + "type": "text", + "text": "I have shown you two images. 
Please describe the two images to me.", + }, + ], + }, + ], + temperature=0.7, + max_tokens=1024, + stream=True, + ) + stream_response = "" + + for chunk in stream_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + stream_response += content + sys.stdout.write(content) + sys.stdout.flush() + + print("-" * 30) + + +def video_stream_request_test(client, video_path): + print("------------------------Video Stream Request Test----------------------") + messages = prepare_video_messages(video_path) + + video_request = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=1024, + stream=True, + ) + print("-" * 30) + video_response = "" + + for chunk in video_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + video_response += content + sys.stdout.write(content) + sys.stdout.flush() + print("-" * 30) + + +def image_speed_test(client): + print("----------------------Image Speed Test----------------------") + start_time = time.time() + request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "Please describe this image. Please list the benchmarks and the models.", + }, + ], + }, + ], + temperature=0, + max_tokens=1024, + ) + end_time = time.time() + response = request.choices[0].message.content + print(response) + print("-" * 30) + print_speed_test_results(request, start_time, end_time) + + +def video_speed_test(client, video_path): + print("------------------------Video Speed Test------------------------") + messages = prepare_video_messages(video_path) + + start_time = time.time() + video_request = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=1024, + ) + end_time = time.time() + video_response = video_request.choices[0].message.content + print(video_response) + print("-" * 30) + print_speed_test_results(video_request, start_time, end_time) + + +def prepare_video_messages(video_path): + max_frames_num = 32 + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, max_frames_num, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + pil_img = Image.fromarray(frame) + buff = io.BytesIO() + pil_img.save(buff, format="JPEG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + base64_frames.append(base64_str) + + messages = [{"role": "user", "content": []}] + frame_format = { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{}"}, + } + + for base64_frame in base64_frames: + frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( + base64_frame + ) + messages[0]["content"].append(frame_format.copy()) + + prompt = {"type": "text", "text": "Please describe the video in detail."} + messages[0]["content"].append(prompt) + + return messages + + +def print_speed_test_results(request, start_time, end_time): + total_tokens = request.usage.total_tokens + completion_tokens = request.usage.completion_tokens + prompt_tokens = request.usage.prompt_tokens + + print(f"Total tokens: {total_tokens}") + print(f"Completion tokens: {completion_tokens}") + print(f"Prompt 
tokens: {prompt_tokens}") + print(f"Time taken: {end_time - start_time} seconds") + print(f"Token per second: {total_tokens / (end_time - start_time)}") + print(f"Completion token per second: {completion_tokens / (end_time - start_time)}") + print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}") + + +def main(): + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + video_path = download_video(url, cache_dir) + + client = create_openai_client("http://127.0.0.1:30000/v1") + + image_stream_request_test(client) + multi_image_stream_request_test(client) + video_stream_request_test(client, video_path) + image_speed_test(client) + video_speed_test(client, video_path) + + +if __name__ == "__main__": + main() diff --git a/examples/usage/llava/http_qwen_llava_test.py b/examples/runtime/llava_onevision/http_qwen_llava_test.py similarity index 95% rename from examples/usage/llava/http_qwen_llava_test.py rename to examples/runtime/llava_onevision/http_qwen_llava_test.py index 1c29658c60..dca56e7a33 100644 --- a/examples/usage/llava/http_qwen_llava_test.py +++ b/examples/runtime/llava_onevision/http_qwen_llava_test.py @@ -4,7 +4,7 @@ # Installing latest sglang. # Endpoint Service CLI: -# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4 +python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8 python3 http_qwen_llava_test.py @@ -16,7 +16,6 @@ import asyncio import copy import json -import time import aiohttp import requests diff --git a/examples/usage/openai_batch_chat.py b/examples/runtime/openai_batch_chat.py similarity index 100% rename from examples/usage/openai_batch_chat.py rename to examples/runtime/openai_batch_chat.py diff --git a/examples/usage/openai_batch_complete.py b/examples/runtime/openai_batch_complete.py similarity index 100% rename from examples/usage/openai_batch_complete.py rename to examples/runtime/openai_batch_complete.py diff --git a/examples/usage/llava/srt_llava_next_test.py b/examples/usage/llava/srt_llava_next_test.py deleted file mode 100644 index 0f9621648a..0000000000 --- a/examples/usage/llava/srt_llava_next_test.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Usage: python3 srt_example_llava.py -""" - -from PIL import ImageFile - -import sglang as sgl -from sglang.lang.chat_template import get_chat_template -from sglang.srt.utils import load_image - -ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images - - -@sgl.function -def image_qa(s, image, question): - s += sgl.user(sgl.image(image) + question) - s += sgl.assistant(sgl.gen("answer")) - - -def single(): - image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" - pil_image, _ = load_image(image_url) - state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512) - print(state["answer"], "\n") - - -def stream(): - image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" - pil_image, _ = load_image(image_url) - state = image_qa.run( - image=pil_image, - question="Please generate short caption for this image.", - max_new_tokens=512, - temperature=0, - stream=True, - ) - - for out in state.text_iter("answer"): - print(out, end="", flush=True) - print() - - -def batch(): - image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" - pil_image, _ = 
load_image(image_url) - states = image_qa.run_batch( - [ - {"image": pil_image, "question": "What is this?"}, - {"image": pil_image, "question": "What is this?"}, - ], - max_new_tokens=512, - ) - for s in states: - print(s["answer"], "\n") - - -if __name__ == "__main__": - import multiprocessing as mp - - mp.set_start_method("spawn", force=True) - runtime = sgl.Runtime( - model_path="lmms-lab/llama3-llava-next-8b", - tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer", - ) - runtime.endpoint.chat_template = get_chat_template("llama-3-instruct") - # runtime = sgl.Runtime( - # model_path="lmms-lab/llava-next-72b", - # tokenizer_path="lmms-lab/llavanext-qwen-tokenizer", - # ) - # runtime.endpoint.chat_template = get_chat_template("chatml-llava") - sgl.set_default_backend(runtime) - print(f"chat template: {runtime.endpoint.chat_template.name}") - - # Or you can use API models - # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview")) - # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision")) - - # Run a single request - print("\n========== single ==========\n") - single() - - # Stream output - print("\n========== stream ==========\n") - stream() - - # Run a batch of requests - print("\n========== batch ==========\n") - batch() - - runtime.shutdown() diff --git a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 b/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 deleted file mode 100644 index 32d912dbfa..0000000000 Binary files a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 and /dev/null differ diff --git a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py deleted file mode 100644 index 753e66c744..0000000000 --- a/examples/usage/openai_parallel_sample.py +++ /dev/null @@ -1,153 +0,0 @@ -import openai - -client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") - -# Text completion -response = client.completions.create( - model="default", - prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", - n=1, - temperature=0.8, - max_tokens=32, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", - n=5, - temperature=0.8, - max_tokens=320, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", - n=3, - temperature=0.8, - max_tokens=32, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=["The name of the famous soccer player is"], - n=1, - temperature=0.8, - max_tokens=128, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=["The name of the famous soccer player is ", "The capital of US is"], - n=1, - temperature=0.8, - max_tokens=32, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=["The name of the famous soccer player is ", "The capital of US is"], - n=3, - temperature=0.8, - max_tokens=32, -) -print(response) - - -response = client.completions.create( - model="default", - prompt=[ - "prompt1: I am a robot and I want to learn like humans. Now let's begin a tale. 
Once upon a time, there was a small", - "prompt2: As a robot, my goal is to understand human learning. Let's start a story. In a faraway land, there lived a tiny", - "prompt3: Being a robot, I aspire to study like people. Let's share a story. Long ago, there was a little", - "prompt4: I am a robot aiming to learn like humans. Let's narrate a story. Once, in a distant kingdom, there was a young", - "prompt5: As a robot, I seek to learn in human ways. Let's tell a story. Once upon a time, in a small village, there was a young", - ], - n=1, - temperature=0.8, - max_tokens=320, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=[ - "The capital of France is", - "The capital of Germany is", - "The capital of US is", - ], - n=3, - temperature=0.8, - max_tokens=32, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - logprobs=True, - top_logprobs=3, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - n=1, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - logprobs=True, - top_logprobs=3, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - n=4, -) -print(response) diff --git a/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png b/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png deleted file mode 100644 index 2ea09fdc60..0000000000 Binary files a/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png and /dev/null differ diff --git a/python/pyproject.toml b/python/pyproject.toml index 32d4912a3a..87c99bffae 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.2.11" +version = "0.2.14.post2" description = "SGLang is yet another fast serving framework for large language models and vision language models." 
readme = "README.md" requires-python = ">=3.8" @@ -20,14 +20,14 @@ dependencies = [ ] [project.optional-dependencies] -srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", +srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", "psutil", "pydantic", "python-multipart", "torch", "uvicorn", "uvloop", "zmq", - "vllm==0.5.4", "outlines>=0.0.44"] + "vllm==0.5.5", "outlines>=0.0.44"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -test = ["jsonlines", "matplotlib", "pandas"] +test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] diff --git a/python/sglang/api.py b/python/sglang/api.py index 5a177c36b0..9405606b71 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -62,9 +62,11 @@ def gen( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, + min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, @@ -72,10 +74,11 @@ def gen( logprob_start_len: Optional[int] = None, top_logprobs_num: Optional[int] = None, return_text_in_logprobs: Optional[bool] = None, - dtype: Optional[type] = None, + dtype: Optional[Union[type, str]] = None, choices: Optional[List[str]] = None, choices_method: Optional[ChoicesSamplingMethod] = None, regex: Optional[str] = None, + json_schema: Optional[str] = None, ): """Call the model to generate. 
See the meaning of the arguments in docs/en/sampling_params.md""" @@ -98,9 +101,11 @@ def gen( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, + min_p, frequency_penalty, presence_penalty, ignore_eos, @@ -110,6 +115,7 @@ def gen( return_text_in_logprobs, dtype, regex, + json_schema, ) @@ -117,9 +123,11 @@ def gen_int( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, + min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, @@ -132,9 +140,11 @@ def gen_int( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, + min_p, frequency_penalty, presence_penalty, ignore_eos, @@ -151,9 +161,11 @@ def gen_string( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, + min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, @@ -166,9 +178,11 @@ def gen_string( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, + min_p, frequency_penalty, presence_penalty, ignore_eos, diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index c2b956e1da..3a48740857 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -54,7 +54,7 @@ from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.sampling_params import SamplingParams +from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.utils import suppress_other_loggers @@ -64,7 +64,7 @@ class BenchArgs: run_name: str = "before" batch_size: Tuple[int] = (1,) input_len: Tuple[int] = (1024,) - output_len: Tuple[int] = (4,) + output_len: Tuple[int] = (16,) result_filename: str = "" correctness_test: bool = False # This is only used for correctness test @@ -111,7 +111,11 @@ def load_model(server_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - model_config = ModelConfig(path=server_args.model_path) + model_config = ModelConfig( + server_args.model_path, + server_args.trust_remote_code, + context_length=server_args.context_length, + ) model_runner = ModelRunner( model_config=model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -195,17 +199,17 @@ def extend(reqs, model_runner): token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, ) - batch.prepare_for_extend(model_runner.model_config.vocab_size, None) - output = model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) - return next_token_ids, output.next_token_logits, batch + batch.prepare_for_extend(model_runner.model_config.vocab_size) + sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = sample_output.batch_next_token_ids.tolist() + return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): - 
batch.prepare_for_decode(input_token_ids.cpu().numpy()) - output = model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids = batch.sample(output.next_token_logits) - return next_token_ids, output.next_token_logits + batch.prepare_for_decode(input_token_ids) + sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids = sample_output.batch_next_token_ids.tolist() + return next_token_ids, logits_output.next_token_logits @torch.inference_mode() @@ -221,6 +225,7 @@ def correctness_test( # Prepare inputs input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) + rank_print(f"{input_ids=}") if bench_args.cut_len > 0: # Prefill @@ -349,7 +354,7 @@ def latency_test( for bs, il, ol in itertools.product( bench_args.batch_size, bench_args.input_len, bench_args.output_len ): - req = prepare_synthetic_inputs_for_latency_test(bs, il) + reqs = prepare_synthetic_inputs_for_latency_test(bs, il) ret = latency_test_run_once( bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol ) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index e3a2ad0a2c..69d175d843 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -149,10 +149,12 @@ async def async_request_openai_completions( "completions" ), "OpenAI Completions API URL must end with 'completions'." + prompt = request_func_input.prompt + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { "model": request_func_input.model, - "prompt": request_func_input.prompt, + "prompt": prompt, "temperature": 0.0, "best_of": 1, "max_tokens": request_func_input.output_len, @@ -220,6 +222,13 @@ async def async_request_openai_completions( return output +async def async_request_gserver( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + raise NotImplementedError() + + def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": import huggingface_hub.constants @@ -238,6 +247,13 @@ def get_model(pretrained_model_name_or_path: str) -> str: def get_tokenizer( pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path.endswith( + ".json" + ) or pretrained_model_name_or_path.endswith(".model"): + from sglang.srt.hf_transformers_utils import get_tokenizer + + return get_tokenizer(pretrained_model_name_or_path) + if pretrained_model_name_or_path is not None and not os.path.exists( pretrained_model_name_or_path ): @@ -252,6 +268,7 @@ def get_tokenizer( "vllm": async_request_openai_completions, "lmdeploy": async_request_openai_completions, "trt": async_request_trt_llm, + "gserver": async_request_gserver, } @@ -351,9 +368,9 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids + prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] - completion_token_ids = tokenizer(completion).input_ids + completion_token_ids = tokenizer.encode(completion) prompt_len = len(prompt_token_ids) output_len = ( len(completion_token_ids) if fixed_output_len is None else fixed_output_len @@ -361,7 +378,9 @@ def sample_sharegpt_requests( if prompt_len < 4 or output_len < 4: # Prune too short sequences. 
continue - if prompt_len > 1024 or prompt_len + output_len > 2048: + if prompt_len > 1024 or ( + prompt_len + output_len > 2048 and fixed_output_len is None + ): # Prune too long sequences. continue filtered_dataset.append((prompt, prompt_len, output_len)) @@ -422,7 +441,7 @@ def sample_random_requests( for i in range(num_prompts): # Tokenize the prompts and completions. prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids + prompt_token_ids = tokenizer.encode(prompt) prompt_len = len(prompt_token_ids) if prompt_len > input_lens[i]: @@ -488,7 +507,7 @@ def calculate_metrics( output_len = outputs[i].output_len output_lens.append(output_len) retokenized_output_len = len( - tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + tokenizer.encode(outputs[i].generated_text, add_special_tokens=False) ) retokenized_output_lens.append(retokenized_output_len) total_input += input_requests[i][1] @@ -547,7 +566,6 @@ async def benchmark( input_requests: List[Tuple[str, int, int]], request_rate: float, disable_tqdm: bool, - enable_multi: bool, extra_request_body: Dict[str, Any], ): if backend in ASYNC_REQUEST_FUNCS: @@ -669,19 +687,20 @@ async def benchmark( "backend": args.backend, "dataset_name": args.dataset_name, "request_rate": request_rate, - "total_input": metrics.total_input, - "total_output": metrics.total_output, - "total_output_retokenized": metrics.total_output_retokenized, - "mean_e2e_latency": metrics.mean_e2e_latency_ms, - "median_e2e_latency": metrics.median_e2e_latency_ms, - "median_ttft": metrics.median_ttft_ms, - "median_itl": metrics.median_itl_ms, - "output_token_throughput": metrics.output_throughput, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "median_itl_ms": metrics.median_itl_ms, + "output_throughput": metrics.output_throughput, "sharegpt_output_len": args.sharegpt_output_len, "random_input_len": args.random_input_len, "random_output_len": args.random_output_len, "random_range_ratio": args.random_range_ratio, - "benchmark_duration": benchmark_duration, + "duration": benchmark_duration, + "completed": metrics.completed, } else: print(f"Error running benchmark for request rate: {request_rate}") @@ -755,6 +774,7 @@ def run_benchmark(args_: argparse.Namespace): global args args = args_ + # Set global environments set_ulimit() random.seed(args.seed) np.random.seed(args.seed) @@ -763,12 +783,14 @@ def run_benchmark(args_: argparse.Namespace): if args.extra_request_body: extra_request_body = json.loads(args.extra_request_body) + # Set url if args.port is None: args.port = { "sglang": 30000, "lmdeploy": 23333, "vllm": 8000, "trt": 8000, + "gserver": 9988, }.get(args.backend, 30000) api_url = ( @@ -791,7 +813,11 @@ def run_benchmark(args_: argparse.Namespace): if args.model is None: print("Please provide a model using `--model` when using `trt` backend.") sys.exit(1) + elif args.backend == "gserver": + api_url = args.base_url if args.base_url else f"{args.host}:{args.port}" + args.model = args.model or "default" + # Get model name if args.model is None: try: response = requests.get(model_url) @@ -816,6 +842,7 @@ def run_benchmark(args_: argparse.Namespace): print(f"{args}\n") + # Read dataset backend = args.backend model_id = args.model tokenizer_id = args.tokenizer if 
args.tokenizer is not None else args.model @@ -841,7 +868,21 @@ def run_benchmark(args_: argparse.Namespace): else: raise ValueError(f"Unknown dataset: {args.dataset_name}") - if args.multi: + if not args.multi: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + extra_request_body=extra_request_body, + ) + ) + else: + # Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts request_rates = parse_request_rate_range(args.request_rate_range) for rate in request_rates: @@ -854,27 +895,11 @@ def run_benchmark(args_: argparse.Namespace): input_requests=input_requests, request_rate=rate, disable_tqdm=args.disable_tqdm, - enable_multi=args.multi, extra_request_body=extra_request_body, ) ) - else: - return asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - model_id=model_id, - tokenizer=tokenizer, - input_requests=input_requests, - request_rate=args.request_rate, - disable_tqdm=args.disable_tqdm, - enable_multi=args.multi, - extra_request_body=extra_request_body, - ) - ) -# to avoid relying on SGLang's components def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) @@ -965,9 +990,9 @@ def set_ulimit(target_soft_limit=65535): type=float, default=float("inf"), help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " - "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.", + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", ) - parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--multi", action="store_true", diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py index cc8ba10e00..4db1f82fc0 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -170,6 +170,17 @@ def get_gpu_topology(): return None +def get_hypervisor_vendor(): + try: + output = subprocess.check_output(["lscpu"], text=True) + for line in output.split("\n"): + if "Hypervisor vendor:" in line: + return line.split(":")[1].strip() + return None + except: + return None + + def check_env(): """ Check and print environment information. 
@@ -184,6 +195,10 @@ def check_env(): if gpu_topo: env_info["NVIDIA Topology"] = gpu_topo + hypervisor_vendor = get_hypervisor_vendor() + if hypervisor_vendor: + env_info["Hypervisor vendor"] = hypervisor_vendor + ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE) env_info["ulimit soft"] = ulimit_soft diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index b02ce9f81e..d5f16e2ae5 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -27,7 +27,7 @@ def __init__(self): # Runtime constants: others self.num_continue_decode_steps = 10 self.retract_decode_steps = 20 - self.flashinfer_workspace_size = 192 * 1024 * 1024 + self.flashinfer_workspace_size = 384 * 1024 * 1024 # Output tokenization configs self.skip_special_tokens_in_output = True diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 7f0db5b359..5012f646ea 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,21 +1,23 @@ import json +import warnings from typing import List, Optional from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template_by_model_path -from sglang.lang.choices import ( - ChoicesDecision, - ChoicesSamplingMethod, - token_length_normalized, -) +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SglSamplingParams +from sglang.lang.ir import ( + REGEX_BOOL, + REGEX_FLOAT, + REGEX_INT, + REGEX_STR, + SglSamplingParams, +) from sglang.utils import http_request class RuntimeEndpoint(BaseBackend): - def __init__( self, base_url: str, @@ -95,32 +97,52 @@ def fill_image(self, s: StreamExecutor): ) self._assert_success(res) + def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams): + if sampling_params.dtype is None: + return + + if sampling_params.stop == (): + sampling_params.stop = [] + + dtype_regex = None + if sampling_params.dtype in ["int", int]: + + dtype_regex = REGEX_INT + sampling_params.stop.extend([" ", "\n"]) + elif sampling_params.dtype in ["float", float]: + + dtype_regex = REGEX_FLOAT + sampling_params.stop.extend([" ", "\n"]) + elif sampling_params.dtype in ["str", str]: + + dtype_regex = REGEX_STR + elif sampling_params.dtype in ["bool", bool]: + + dtype_regex = REGEX_BOOL + else: + raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + + if dtype_regex is not None and sampling_params.regex is not None: + warnings.warn( + f"Both dtype and regex are set. Only dtype will be used. 
dtype: {sampling_params.dtype}, regex: {sampling_params.regex}" + ) + + sampling_params.regex = dtype_regex + def generate( self, s: StreamExecutor, sampling_params: SglSamplingParams, ): - if sampling_params.dtype is None: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - **sampling_params.to_srt_kwargs(), - }, - } - elif sampling_params.dtype in [int, "int"]: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - "dtype": "int", - **sampling_params.to_srt_kwargs(), - }, - } - else: - raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + self._handle_dtype_to_regex(sampling_params) + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, + **sampling_params.to_srt_kwargs(), + }, + } for item in [ "return_logprob", @@ -151,27 +173,16 @@ def generate_stream( s: StreamExecutor, sampling_params: SglSamplingParams, ): - if sampling_params.dtype is None: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - **sampling_params.to_srt_kwargs(), - }, - } - elif sampling_params.dtype in [int, "int"]: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - "dtype": "int", - **sampling_params.to_srt_kwargs(), - }, - } - else: - raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + self._handle_dtype_to_regex(sampling_params) + + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, + **sampling_params.to_srt_kwargs(), + }, + } for item in [ "return_logprob", diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index bfde4bbdb6..fa300b25f0 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum, auto -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple class ChatTemplateStyle(Enum): @@ -137,7 +137,7 @@ def get_chat_template_by_model_path(model_path): register_chat_template( ChatTemplate( name="chatml-llava", - default_system_prompt="Answer the questions.", + default_system_prompt="You are a helpful assistant.", role_prefix_and_suffix={ "system": ("<|im_start|>system\n", "<|im_end|>\n"), "user": ("<|im_start|>user\n", "<|im_end|>\n"), @@ -145,7 +145,7 @@ def get_chat_template_by_model_path(model_path): }, style=ChatTemplateStyle.PLAIN, stop_str=("<|im_end|>",), - image_token=" \n", + image_token="\n", ) ) @@ -322,12 +322,17 @@ def match_chat_ml(model_path: str): if "tinyllama" in model_path: return get_chat_template("chatml") # Now the suffix for qwen2 chat model is "instruct" - if "qwen" in model_path and ("chat" in model_path 
or "instruct" in model_path): + if ( + "qwen" in model_path + and ("chat" in model_path or "instruct" in model_path) + and ("llava" not in model_path) + ): return get_chat_template("qwen") if ( "llava-v1.6-34b" in model_path or "llava-v1.6-yi-34b" in model_path or "llava-next-video-34b" in model_path + or "llava-onevision-qwen2" in model_path ): return get_chat_template("chatml-llava") diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py index 95af04adb0..5e1b411fc2 100644 --- a/python/sglang/lang/compiler.py +++ b/python/sglang/lang/compiler.py @@ -130,6 +130,7 @@ def run( temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, backend=None, @@ -145,6 +146,7 @@ def run( temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ) @@ -160,6 +162,7 @@ def run_batch( temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, backend=None, @@ -178,6 +181,7 @@ def run_batch( temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index cf53fac303..91f48456aa 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -20,7 +20,6 @@ SglConstantText, SglExpr, SglExprList, - SglFunction, SglGen, SglImage, SglRoleBegin, @@ -181,8 +180,10 @@ def __init__( num_api_spec_tokens=None, use_thread=True, ): + from sglang.lang.backend.base_backend import BaseBackend + self.sid = uuid.uuid4().hex - self.backend = backend + self.backend: BaseBackend = backend self.arguments: Dict[str, Any] = arguments self.default_sampling_para = default_sampling_para self.stream = stream @@ -658,9 +659,11 @@ def _resolve_sampling_params(self, sampling_params): for item in [ "max_new_tokens", "stop", + "stop_token_ids", "temperature", "top_p", "top_k", + "min_p", "frequency_penalty", "presence_penalty", "ignore_eos", @@ -670,6 +673,7 @@ def _resolve_sampling_params(self, sampling_params): "return_text_in_logprobs", "dtype", "regex", + "json_schema", ]: value = getattr(sampling_params, item, None) if value is not None: diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 135110c1e0..99a3e8e68b 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -8,19 +8,21 @@ from sglang.global_config import global_config from sglang.lang.choices import ChoicesSamplingMethod -REGEX_INT = r"[-+]?[0-9]+" -REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+" +REGEX_INT = r"[-+]?[0-9]+[ \n]*" +REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*" REGEX_BOOL = r"(True|False)" -REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg +REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg @dataclasses.dataclass class SglSamplingParams: max_new_tokens: int = 128 stop: Union[str, List[str]] = () + stop_token_ids: Optional[List[int]] = () temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 # -1 means disable + min_p: float = 0.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 ignore_eos: bool = False @@ -28,6 +30,7 @@ class SglSamplingParams: logprob_start_len: Optional[int] = (None,) top_logprobs_num: Optional[int] = (None,) return_text_in_logprobs: Optional[bool] = (None,) + json_schema: Optional[str] 
= None # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None @@ -37,9 +40,11 @@ def clone(self): return SglSamplingParams( self.max_new_tokens, self.stop, + self.stop_token_ids, self.temperature, self.top_p, self.top_k, + self.min_p, self.frequency_penalty, self.presence_penalty, self.ignore_eos, @@ -47,6 +52,7 @@ def clone(self): self.logprob_start_len, self.top_logprobs_num, self.return_text_in_logprobs, + self.json_schema, ) def to_openai_kwargs(self): @@ -108,13 +114,16 @@ def to_srt_kwargs(self): return { "max_new_tokens": self.max_new_tokens, "stop": self.stop, + "stop_token_ids": self.stop_token_ids, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, + "min_p": self.min_p, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, "ignore_eos": self.ignore_eos, "regex": self.regex, + "json_schema": self.json_schema, } @@ -141,10 +150,12 @@ def run( self, *args, max_new_tokens: int = 128, - stop: Union[str, List[str]] = (), + stop: Union[str, List[str]] = [], + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, ignore_eos: bool = False, @@ -161,9 +172,11 @@ def run( default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, @@ -181,9 +194,11 @@ def run_batch( *, max_new_tokens: int = 128, stop: Union[str, List[str]] = (), + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, ignore_eos: bool = False, @@ -218,9 +233,11 @@ def run_batch( default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, @@ -397,9 +414,11 @@ def __init__( name: Optional[str] = None, max_new_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, + min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, @@ -409,6 +428,7 @@ def __init__( return_text_in_logprobs: Optional[bool] = None, dtype: Optional[type] = None, regex: Optional[str] = None, + json_schema: Optional[str] = None, ): """Call the model to generate. 
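
Note on the new `min_p` field threaded through `SglSamplingParams`, `run`, `run_batch`, and `SglGen` above: min-p sampling drops any token whose probability falls below `min_p` times the probability of the most likely token, then renormalizes before drawing. The snippet below is a minimal standalone sketch of that rule, independent of the sglang sampler internals; the function name and values are invented for illustration.

```python
import math
import random


def min_p_sample(logits, min_p=0.1):
    """Sample a token id after min-p filtering (illustrative sketch only)."""
    # Softmax over the raw logits.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    probs = [e / sum(exps) for e in exps]

    # Keep tokens whose probability is at least min_p * max(probs).
    threshold = min_p * max(probs)
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    kept = [p / total for p in kept]

    # Multinomial draw over the surviving tokens.
    return random.choices(range(len(logits)), weights=kept, k=1)[0]


print(min_p_sample([2.0, 1.0, 0.2, -3.0], min_p=0.2))
```
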
See the meaning of the arguments in docs/en/sampling_params.md""" super().__init__() @@ -416,9 +436,11 @@ def __init__( self.sampling_params = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, @@ -428,6 +450,7 @@ def __init__( return_text_in_logprobs=return_text_in_logprobs, dtype=dtype, regex=regex, + json_schema=json_schema, ) def __repr__(self): diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 91dc0dc4e9..1df64e848c 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -1,9 +1,11 @@ """Launch the inference server.""" import argparse +import os from sglang.srt.server import launch_server from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_child_process if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -11,4 +13,9 @@ args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) - launch_server(server_args) + try: + launch_server(server_args) + except Exception as e: + raise e + finally: + kill_child_process(os.getpid(), including_parent=False) diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index c34dd21167..797ad07a47 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -5,8 +5,12 @@ from sglang.srt.server import ServerArgs, launch_server if __name__ == "__main__": - model_overide_args = {} + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + model_overide_args = {} model_overide_args["mm_spatial_pool_stride"] = 2 model_overide_args["architectures"] = ["LlavaVidForCausalLM"] model_overide_args["num_frames"] = 16 @@ -16,14 +20,7 @@ model_overide_args["max_sequence_length"] = 4096 * 2 model_overide_args["tokenizer_model_max_length"] = 4096 * 2 model_overide_args["model_max_length"] = 4096 * 2 - - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args() - if "34b" in args.model_path.lower(): model_overide_args["image_token_index"] = 64002 - server_args = ServerArgs.from_cli_args(args) - launch_server(server_args, model_overide_args, None) diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py new file mode 100644 index 0000000000..9e74366709 --- /dev/null +++ b/python/sglang/srt/configs/__init__.py @@ -0,0 +1,5 @@ +from sglang.srt.configs.exaone import ExaoneConfig + +__all__ = [ + "ExaoneConfig", +] diff --git a/python/sglang/srt/configs/exaone.py b/python/sglang/srt/configs/exaone.py new file mode 100644 index 0000000000..7b0a2d290d --- /dev/null +++ b/python/sglang/srt/configs/exaone.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved. +# Copyright 2024 The LG CNS AI Engineering Team. +# Copyright 2023-2024 SGLang Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" EXAONE model configuration """ +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {} + + +# ruff: noqa: E501 +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.ExaoneModel`. It is used to + instantiate a EXAONE model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Exaone + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 102400): + Vocabulary size of the EXAONE model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.ExaoneModel`. Vocabulary size of the model. + Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of + :class:`~transformers.EXAONEModel`. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (:obj:`int`, `optional`): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + intermediate_size (:obj:`int`, `optional`, defaults to `hidden_size * 4`): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`): + The non-linear activation function (function or string) in the decoder. + rope_theta (:obj:`float`, `optional`, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (:obj:`Dict`, `optional`): + Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (:obj:`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (:obj:`float`, `optional`): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (:obj:`int`, `optional`): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (:obj:`float`, `optional`): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (:obj:`float`, `optional`): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (:obj:`float`, `optional`): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (:obj:`List[float]`, `optional`): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (:obj:`List[float]`, `optional`): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (:obj:`float`, `optional`): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (:obj:`float`, `optional`): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``configs.is_decoder=True``. + bos_token_id (:obj:`int`, `optional`, defaults to 0): + Beginning of stream token id. + eos_token_id (:obj:`int`, `optional`, defaults to 2): + End of stream token id. 
+ tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to tie weight embeddings + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import EXAONEModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = EXAONEModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.configs + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rope_theta=10000.0, + rope_scaling=None, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/python/sglang/srt/constrained/base_tool_cache.py b/python/sglang/srt/constrained/base_tool_cache.py index 4cbb6bd226..fa1aff5eac 100644 --- a/python/sglang/srt/constrained/base_tool_cache.py +++ b/python/sglang/srt/constrained/base_tool_cache.py @@ -54,7 +54,7 @@ def _init_with_timer(key): return val def init_value(self, key): - raise NotImplementedError + raise NotImplementedError() def get_cache_hit_rate(self): if self.metrics["total"] == 0: diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index fa41f90de3..6bc6ea6d26 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -15,6 +15,8 @@ """Cache for the compressed finite state machine.""" +from outlines.fsm.json_schema import build_regex_from_schema + from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_tool_cache import BaseToolCache @@ -26,9 +28,12 @@ def __init__( tokenizer_args_dict, enable=True, skip_tokenizer_init=False, + json_schema_mode=False, ): super().__init__(enable=enable) + self.json_schema_mode = json_schema_mode + if ( skip_tokenizer_init or tokenizer_path.endswith(".json") @@ -72,5 +77,9 @@ def fset(self, value): tokenizer_path, **tokenizer_args_dict ) - def init_value(self, 
regex): - return RegexGuide(regex, self.outlines_tokenizer) + def init_value(self, value): + if self.json_schema_mode: + regex = build_regex_from_schema(value) + return RegexGuide(regex, self.outlines_tokenizer), regex + else: + return RegexGuide(value, self.outlines_tokenizer) diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index 7b694318e4..244931e050 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -23,6 +23,7 @@ import interegular import outlines.caching +from outlines.fsm.json_schema import build_regex_from_schema from sglang.srt.constrained import ( FSMInfo, @@ -62,16 +63,22 @@ def _init_state_to_jump_forward(regex_string): id_to_symbol.setdefault(id_, []).append(symbol) transitions = fsm_info.transitions + outgoings_ct = defaultdict(int) - state_to_jump_forward = {} + # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally + for s in fsm_info.finals: + outgoings_ct[s] = 1 + state_to_jump_forward = {} for (state, id_), next_state in transitions.items(): if id_ == fsm_info.alphabet_anything_value: + # Arbitrarily symbol cannot be recognized as jump forward continue + symbols = id_to_symbol[id_] for c in symbols: if len(c) > 1: - # Skip byte level transitions + # Skip byte level transitions like c = "5E" continue outgoings_ct[state] += 1 @@ -87,6 +94,9 @@ def _init_state_to_jump_forward(regex_string): # Process the byte level jump forward outgoings_ct = defaultdict(int) + for s in fsm_info.finals: + outgoings_ct[s] = 1 + for (state, id_), next_state in transitions.items(): if id_ == fsm_info.alphabet_anything_value: continue @@ -177,3 +187,5 @@ def test_main(regex_string): test_main(r"霍格沃茨特快列车|霍比特人比尔博") # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ... # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ... 
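
In `json_schema_mode`, the FSM cache compiles a JSON schema into a regex with outlines' `build_regex_from_schema` (the import added above) and returns both the compiled `RegexGuide` and the regex string. The short example below only shows what that compilation step produces; it assumes an outlines version that exposes `outlines.fsm.json_schema.build_regex_from_schema`, as the diff imports, and the schema itself is made up.

```python
import json

from outlines.fsm.json_schema import build_regex_from_schema

schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }
)

# In json_schema_mode, init_value builds a RegexGuide from this regex and
# keeps the regex string alongside it.
regex = build_regex_from_schema(schema)
print(regex[:120], "...")
```
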
+ + test_main(r"[-+]?[0-9]+[ ]*") diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 5ee1216974..d5ca327703 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum): NO_COLON_TWO = auto() ADD_NEW_LINE_SINGLE = auto() LLAMA2 = auto() + LLAMA3 = auto() CHATGLM = auto() CHATML = auto() CHATINTERN = auto() @@ -137,6 +138,20 @@ def get_prompt(self) -> str: else: ret += role + ":" return ret + elif self.sep_style == SeparatorStyle.LLAMA3: + ret = "<|begin_of_text|>" + if self.system_message: + ret += system_prompt + else: + ret += "" + for i, (role, message) in enumerate(self.messages): + if message: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += f"{message.strip()}<|eot_id|>" + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + # print(ret) + return ret elif self.sep_style == SeparatorStyle.LLAMA2: seps = [self.sep, self.sep2] if self.system_message: @@ -379,12 +394,23 @@ def generate_chat_conv( conv.append_message(conv.roles[0], message.content) else: real_content = "" + # calculate number of image_url + num_image_url = 0 + for content in message.content: + if content.type == "image_url": + num_image_url += 1 + if num_image_url > 1: + image_token = "" + else: + image_token = "\n" for content in message.content: if content.type == "text": + if num_image_url > 16: + real_content += "\n" # for video real_content += content.text elif content.type == "image_url": # NOTE: Only works for llava - real_content += "\n" + real_content += image_token conv.append_image(content.image_url.url) conv.append_message(conv.roles[0], real_content) elif msg_role == "assistant": @@ -425,6 +451,18 @@ def generate_chat_conv( ) ) +register_conv_template( + Conversation( + name="chatml-llava", + system_template="<|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str=["<|endoftext|>", "<|im_end|>"], + ) +) + register_conv_template( Conversation( name="vicuna_v1.1", @@ -437,6 +475,17 @@ def generate_chat_conv( ) ) +register_conv_template( + Conversation( + name="llava_llama_3", + system_message="You are a helpful language and vision assistant. 
You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + ) +) # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442 register_conv_template( Conversation( diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 508843a395..7fce3b2401 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -15,6 +15,7 @@ """Utilities for Huggingface Transformers.""" +import contextlib import functools import json import os @@ -30,14 +31,26 @@ PreTrainedTokenizer, PreTrainedTokenizerFast, ) -from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig -from sglang.srt.utils import is_multimodal_model +try: + from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig + + from sglang.srt.configs import ExaoneConfig + + _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + ChatGLMConfig.model_type: ChatGLMConfig, + DbrxConfig.model_type: DbrxConfig, + ExaoneConfig.model_type: ExaoneConfig, + } +except ImportError: + # We want this file to run without vllm dependency + _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {} -_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - ChatGLMConfig.model_type: ChatGLMConfig, - DbrxConfig.model_type: DbrxConfig, -} +for name, cls in _CONFIG_REGISTRY.items(): + with contextlib.suppress(ValueError): + AutoConfig.register(name, cls) + +from sglang.srt.utils import is_multimodal_model def download_from_hf(model_path: str): @@ -48,7 +61,7 @@ def download_from_hf(model_path: str): def get_config_json(model_path: str): - with open(os.path.join(model_path, "config.json")) as f: + with open(os.path.join(model_path, "configs.json")) as f: config = json.load(f) return config @@ -84,7 +97,7 @@ def get_config( def get_context_length(config): - """Get the context length of a model from a huggingface model config.""" + """Get the context length of a model from a huggingface model configs.""" rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: rope_scaling_factor = config.rope_scaling["factor"] @@ -114,41 +127,12 @@ def get_tokenizer( tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - if tokenizer_name.endswith(".json"): - return TiktokenTokenizer(tokenizer_name) - - if tokenizer_name.endswith(".model"): - return SentencePieceTokenizer(tokenizer_name) - """Gets a tokenizer for the given model name via Huggingface.""" - if is_multimodal_model(tokenizer_name): - processor = get_processor( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, - **kwargs, - ) - tokenizer = processor.tokenizer - return tokenizer - if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False - if ( - "llama" in tokenizer_name.lower() - and kwargs.get("use_fast", True) - and tokenizer_name != _FAST_LLAMA_TOKENIZER - ): - pass - # warnings.warn( - # "For some LLaMA V1 models, initializing the fast tokenizer may " - # "take a long time. 
To reduce the initialization time, consider " - # f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - # "tokenizer." - # ) try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, @@ -206,129 +190,3 @@ def get_processor( **kwargs, ) return processor - - -class TiktokenTokenizer: - def __init__(self, tokenizer_path): - import tiktoken - from jinja2 import Template - - PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" - - # Read JSON - name = "tmp-json" - with open(tokenizer_path, "rb") as fin: - tok_dict = json.load(fin) - - mergeable_ranks = { - bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"] - } - special_tokens = { - bytes(item["bytes"]).decode(): item["token"] - for item in tok_dict["special_tokens"] - } - assert tok_dict["word_split"] == "V1" - - kwargs = { - "name": name, - "pat_str": tok_dict.get("pat_str", PAT_STR_B), - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - if "default_allowed_special" in tok_dict: - default_allowed_special = set( - [ - bytes(bytes_list).decode() - for bytes_list in tok_dict["default_allowed_special"] - ] - ) - else: - default_allowed_special = None - if "vocab_size" in tok_dict: - kwargs["explicit_n_vocab"] = tok_dict["vocab_size"] - - tokenizer = tiktoken.Encoding(**kwargs) - tokenizer._default_allowed_special = default_allowed_special or set() - tokenizer._default_allowed_special |= {"<|separator|>"} - - def encode_patched( - self, - text: str, - *, - allowed_special: Union[ - Literal["all"], AbstractSet[str] - ] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - ) -> List[int]: - if isinstance(allowed_special, set): - allowed_special |= self._default_allowed_special - return tiktoken.Encoding.encode( - self, - text, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - - tokenizer.encode = functools.partial(encode_patched, tokenizer) - - # Convert to HF interface - self.tokenizer = tokenizer - self.eos_token_id = tokenizer._special_tokens["<|eos|>"] - self.vocab_size = tokenizer.n_vocab - self.chat_template = Template( - "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" - ) - - def encode(self, x, add_special_tokens=False): - return self.tokenizer.encode(x) - - def decode(self, x): - return self.tokenizer.decode(x) - - def batch_decode( - self, batch, skip_special_tokens=True, spaces_between_special_tokens=False - ): - if isinstance(batch[0], int): - batch = [[x] for x in batch] - return self.tokenizer.decode_batch(batch) - - def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) - return self.encode(ret) if tokenize else ret - - -class SentencePieceTokenizer: - def __init__(self, tokenizer_path): - import sentencepiece as spm - from jinja2 import Template - - tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path) - - # Convert to HF interface - self.tokenizer = tokenizer - self.eos_token_id = tokenizer.eos_id() - self.vocab_size = 
tokenizer.vocab_size() - self.chat_template = Template( - "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" - ) - - def encode(self, x, add_special_tokens=False): - return self.tokenizer.encode(x) - - def decode(self, x): - return self.tokenizer.decode(x) - - def batch_decode( - self, batch, skip_special_tokens=True, spaces_between_special_tokens=False - ): - if isinstance(batch[0], int): - batch = [[x] for x in batch] - return self.tokenizer.decode(batch) - - def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) - return self.encode(ret) if tokenize else ret diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 64d3915946..9047197af2 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -13,11 +13,20 @@ """Fused operators for activation layers.""" +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from flashinfer.activation import silu_and_mul +from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.utils import set_weight_attrs class SiluAndMul(CustomOp): @@ -31,3 +40,98 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty(output_shape, dtype=x.dtype, device=x.device) silu_and_mul(x, out) return out + + +class GeluAndMul(CustomOp): + def __init__(self, approximate="tanh"): + super().__init__() + self.approximate = approximate + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if self.approximate == "tanh": + gelu_tanh_and_mul(x, out) + elif self.approximate == "none": + gelu_and_mul(x, out) + else: + raise RuntimeError("GeluAndMul only support tanh or none") + return out + + +class ScaledActivation(nn.Module): + """An activation function with post-scale parameters. + + This is used for some quantization methods like AWQ. 
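
The `GeluAndMul` op added above fuses the gated-GELU pattern: the last dimension is split in half, the first half goes through GELU, and the result is multiplied by the second half. The reference below restates `forward_native` and can be used to sanity-check the fused flashinfer path; the shapes and values are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F


def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Same math as GeluAndMul.forward_native: gelu(first half) * second half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


x = torch.randn(4, 8)
print(gelu_and_mul_reference(x).shape)  # torch.Size([4, 4])
```
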
+ """ + + def __init__( + self, + act_module: nn.Module, + intermediate_size: int, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.act = act_module + self.input_is_parallel = input_is_parallel + if input_is_parallel: + tp_size = get_tensor_model_parallel_world_size() + intermediate_size_per_partition = divide(intermediate_size, tp_size) + else: + intermediate_size_per_partition = intermediate_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.scales = nn.Parameter( + torch.empty(intermediate_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(x) / self.scales + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +_ACTIVATION_REGISTRY = { + "gelu": nn.GELU(), + "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), +} + + +def get_act_fn( + act_fn_name: str, + quant_config: Optional[QuantizationConfig] = None, + intermediate_size: Optional[int] = None, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") + + act_fn = _ACTIVATION_REGISTRY[act_fn_name] + if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names(): + if intermediate_size is None: + raise ValueError( + "intermediate_size must be specified for scaled " + "activation functions." 
+ ) + return ScaledActivation( + act_fn, intermediate_size, input_is_parallel, params_dtype + ) + return act_fn diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py index c868299ef4..dc92a65480 100644 --- a/python/sglang/srt/layers/decode_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -26,7 +26,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict -if global_server_args_dict.get("attention_reduce_in_fp32", False): +if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 REDUCE_TORCH_TYPE = torch.float32 else: @@ -58,7 +58,6 @@ def _fwd_kernel_stage1( att_stride_h, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_DPE: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, ): @@ -78,10 +77,6 @@ def _fwd_kernel_stage1( off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - if BLOCK_DPE > 0: - offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) - off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) block_stard_index = start_n * BLOCK_N @@ -106,19 +101,6 @@ def _fwd_kernel_stage1( other=0.0, ).to(REDUCE_TRITON_TYPE) att_value = tl.sum(q[None, :] * k, 1) - if BLOCK_DPE > 0: - qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE) - offs_buf_kpe = ( - k_loc[:, None] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_dpe[None, :] - ) - kpe = tl.load( - K_Buffer + offs_buf_kpe, - mask=offs_n_new[:, None] < cur_batch_end_index, - other=0.0, - ).to(REDUCE_TRITON_TYPE) - att_value += tl.sum(qpe[None, :] * kpe, 1) att_value *= sm_scale if logit_cap > 0: @@ -214,14 +196,7 @@ def _decode_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256, 576} - - if Lk == 576: - BLOCK_DMODEL = 512 - BLOCK_DPE = 64 - else: - BLOCK_DMODEL = Lk - BLOCK_DPE = 0 + assert Lk in {16, 32, 64, 128, 256} batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -249,8 +224,7 @@ def _decode_att_m_fwd( k_buffer.stride(1), att_out.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_DPE=BLOCK_DPE, + BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, logit_cap=logit_cap, num_warps=num_warps, @@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd( ) +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + att_stride_h, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + logit_cap: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_kv_head + 1) * kv_group_num + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_start_index = 0 + cur_batch_end_index = cur_batch_seq_len + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + off_qpe = ( + cur_batch * stride_qbs + 
cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to( + REDUCE_TRITON_TYPE + ) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + offs_buf_k = ( + k_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=offs_n_new[None, :] < cur_batch_end_index, + other=0.0, + ).to(REDUCE_TRITON_TYPE) + qk = tl.dot(q, k) + if BLOCK_DPE > 0: + qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to( + REDUCE_TRITON_TYPE + ) + offs_buf_kpe = ( + k_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=offs_n_new[None, :] < cur_batch_end_index, + other=0.0, + ).to(REDUCE_TRITON_TYPE) + qk += tl.dot(qpe, kpe) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + offs_o = cur_head[:, None] * att_stride_h + ( + cur_batch_in_all_start_index + offs_n[None, :] + ) + + tl.store( + Att_Out + offs_o, + qk, + mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index), + ) + + +@triton.jit +def _fwd_grouped_kernel_stage2( + Logics, + V_Buffer, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_logic_h, + stride_buf_vbs, + stride_buf_vh, + stride_obs, + stride_oh, + stride_req_to_token_b, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + + cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_kv_head + 1) * kv_group_num + mask_h = mask_h & (cur_head < q_head_num) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] + v_ptrs = V_Buffer + offs_buf_v + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + v_index = tl.load( + Req_to_tokens + + cur_batch_req_idx * stride_req_to_token_b + + (start_n + offs_n), + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0, + ) + + offs_qk = cur_head[:, None] * stride_logic_h + ( + cur_batch_start_loc + start_n + offs_n[None, :] + ) + + qk = tl.load( + Logics + offs_qk, + mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), + other=float("-inf"), + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + old_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + e_sum = e_sum * old_scale + tl.sum(p, 1) + v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + p = p.to(v.dtype) + acc = acc * old_scale[:, None] + tl.dot(p, v) + e_max = n_e_max + + acc = acc / e_sum[:, None] + off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] + 
out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=mask_h[:, None]) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + att_out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + max_len_in_batch, + sm_scale, + logit_cap, +): + BLOCK = 32 + # shape constraints + Lq, Lk = q.shape[-1], k_buffer.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128, 256, 576} + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + else: + BLOCK_DMODEL = Lk + BLOCK_DPE = 0 + + batch, head_num = B_req_idx.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[1] + + BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + triton.cdiv(max_len_in_batch, BLOCK), + ) + + num_warps = 4 + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + att_out.stride(0), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=1, + ) + + +def _decode_grouped_softmax_reducev_fwd( + logics, + v_buffer, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, +): + BLOCK = 128 + batch, head_num = b_seq_len.shape[0], logics.shape[0] + kv_group_num = logics.shape[0] // v_buffer.shape[1] + BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) + grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) + + num_warps = 8 + + _fwd_grouped_kernel_stage2[grid]( + logics, + v_buffer, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + logics.stride(0), + v_buffer.stride(0), + v_buffer.stride(1), + o.stride(0), + o.stride(1), + req_to_tokens.stride(0), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + num_warps=num_warps, + num_stages=1, + ) + + def decode_attention_fwd( q, k_buffer, @@ -316,24 +577,51 @@ def decode_attention_fwd( (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda" ) - _decode_att_m_fwd( - q, - k_buffer, - att_m, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - max_len_in_batch, - sm_scale, - logit_cap, - ) - _decode_softmax_reducev_fwd( - att_m, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + kv_group_num = q.shape[1] // v_buffer.shape[1] + + if kv_group_num == 1: + # MHA + _decode_att_m_fwd( + q, + k_buffer, + att_m, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + sm_scale, + logit_cap, + ) + _decode_softmax_reducev_fwd( + att_m, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + ) + else: + # GQA/MQA/MLA + _decode_grouped_att_m_fwd( + q, + k_buffer, + att_m, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + sm_scale, + logit_cap, + ) + _decode_grouped_softmax_reducev_fwd( + att_m, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + ) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 0a03f65626..097adca3ca 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -275,7 +275,9 @@ def extend_attention_fwd( BLOCK_DPE = 0 BLOCK_DV = Lv - if CUDA_CAPABILITY[0] >= 8: + if CUDA_CAPABILITY[0] >= 9: + BLOCK_M, BLOCK_N = 
(128, 64) + elif CUDA_CAPABILITY[0] >= 8: BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64) else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) diff --git a/python/sglang/srt/layers/fused_moe/__init__.py b/python/sglang/srt/layers/fused_moe/__init__.py new file mode 100644 index 0000000000..5f7691c09f --- /dev/null +++ b/python/sglang/srt/layers/fused_moe/__init__.py @@ -0,0 +1 @@ +from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe/fused_moe.py similarity index 78% rename from python/sglang/srt/layers/fused_moe.py rename to python/sglang/srt/layers/fused_moe/fused_moe.py index c5630fa5db..717be5ce96 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe/fused_moe.py @@ -1,20 +1,5 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - # Adapted from -# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1 +# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe """Fused MoE kernel.""" import functools import json @@ -24,6 +9,7 @@ import torch import triton import triton.language as tl +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -373,6 +359,31 @@ def get_default_config( return config +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, +): + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) + return config + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -403,6 +414,41 @@ def fused_topk( return topk_weights, topk_ids +# This is used by the Deepseek-V2 model +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, 
scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -425,24 +471,23 @@ def fused_experts( assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - M, _ = hidden_states.shape + num_tokens, _ = hidden_states.shape E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + ) - if override_config: - config = override_config - else: - # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) - - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None - ) + config = get_config_func(M) intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), @@ -460,56 +505,85 @@ def fused_experts( dtype=hidden_states.dtype, ) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) - ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
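
The `grouped_topk` routing added above (used by DeepSeek-V2 style MoE layers) partitions the experts into groups, keeps only the best `topk_group` groups per token, and then picks the final top-k experts inside those groups. Below is a standalone restatement of that logic so it can be run and inspected outside the fused kernel; the tensor sizes are invented for demonstration.

```python
import torch


def grouped_topk(gating_output, topk, num_expert_group, topk_group, renormalize=True):
    # Score every expert, then score each group by its best expert.
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Expand the group mask back to per-expert granularity.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


gating = torch.randn(2, 8)  # 2 tokens, 8 experts split into 4 groups of 2
print(grouped_topk(gating, topk=2, num_expert_group=4, topk_group=2)[1])
```
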
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) + invoke_fused_moe_kernel( + curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) - if inplace: - return torch.sum( + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + torch.sum( intermediate_cache3.view(*intermediate_cache3.shape), dim=1, - out=hidden_states, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + return out_hidden_states def fused_moe( @@ -521,6 +595,9 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -543,6 +620,10 @@ def fused_moe( Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -556,12 +637,18 @@ def fused_moe( # Check constraints. 
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - if hasattr(ops, "topk_softmax"): - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, ) else: - topk_weights, topk_ids = fused_topk_v0_4_3( + topk_weights, topk_ids = fused_topk( hidden_states, gating_output, topk, renormalize ) @@ -579,33 +666,3 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, ) - - -def fused_topk_v0_4_3( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, -): - import vllm._moe_C as moe_kernels - - M, _ = hidden_states.shape - - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py new file mode 100644 index 0000000000..e08ec5c58a --- /dev/null +++ b/python/sglang/srt/layers/fused_moe/layer.py @@ -0,0 +1,587 @@ +# Adapted from +# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe +from abc import abstractmethod +from typing import List, Optional, Tuple + +import torch +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, 
extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + return self.forward( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + ) + + def forward_cuda( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + ) -> torch.Tensor: + from sglang.srt.layers.fused_moe.fused_moe import fused_moe + + return fused_moe( + x, + w1, + w2, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + + def forward_cpu(self, *args, **kwargs): + raise NotImplementedError("The CPU backend currently does not support MoE.") + + def forward_tpu( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe + + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + return fused_moe(x, w1, w2, router_logits, top_k, renormalize) + + +class FusedMoE(torch.nn.Module): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. 
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod() + ) + else: + if isinstance(quant_config, Fp8Config): + self.quant_method = Fp8MoEMethod(quant_config) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + use_presharded_weights: bool = False, + ): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + # shard_id 0 == gate_proj / w1 + # shard_id 2 == up_proj / w3 + if shard_id == 0 or shard_id == 2: + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == 0 else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + # shard_id 1 == down_proj / w2 + else: + param_data[expert_id] = loaded_weight + # Weights + else: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.intermediate_size_per_partition + if use_presharded_weights: + shard = slice(None) + else: + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] + # w2, down_proj case: Load into only shard of w2. 
+ elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[:, shard] + else: + raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + ) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, int]]: + + gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] + gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name] + + return ( + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_scale" + if weight_name in gate_up + else "experts.w2_scale" + ), + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_weight" + if weight_name in gate_up + else "experts.w2_weight" + ), + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.a13_scale" + if weight_name in gate_up + else "experts.a2_scale" + ), + f"experts.{expert_id}.{weight_name}.input_scale", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + ) + + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + per_tensor_dequantize, +) +from vllm.utils import print_warning_once + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. 
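The FP8 path described above quantizes each weight tensor with a single per-tensor scale. A rough torch-only approximation of that scheme is sketched below; the real scaled_fp8_quant kernel may differ in details (saturation behavior, dynamic vs. static scales), per_tensor_quant_fp8 is an assumed name, and the float8_e4m3fn dtype requires a recent PyTorch.

```python
import torch


def per_tensor_quant_fp8(w: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448 for e4m3fn
    scale = w.abs().amax().clamp(min=1e-12) / fp8_max
    q = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale


def per_tensor_dequant_fp8(q: torch.Tensor, scale: torch.Tensor):
    return q.to(torch.float32) * scale


w = torch.randn(128, 256)
q, s = per_tensor_quant_fp8(w)
print((per_tensor_dequant_fp8(q, s) - w).abs().max())       # quantization error
```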
+ """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + a13_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like( + layer.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like( + layer.w2_weight.data, dtype=torch.float8_e4m3fn + ) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :] + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. 
+ else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.a13_scale) or not all_close_1d( + layer.a2_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.a13_scale = torch.nn.Parameter( + layer.a13_scale.max(), requires_grad=False + ) + layer.a2_scale = torch.nn.Parameter( + layer.a2_scale.max(), requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + + from sglang.srt.layers.fused_moe.fused_moe import fused_moe + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index ac4d368d3f..4c24f50ffe 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn -from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from flashinfer.norm import ( + fused_add_rmsnorm, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + rmsnorm, +) from vllm.model_executor.custom_op import CustomOp @@ -63,3 +68,44 @@ def forward_native( return x else: return x, residual + + +class GemmaRMSNorm(CustomOp): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * (1.0 + self.weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> 
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + gemma_fused_add_rmsnorm( + x, residual, self.weight.data, self.variance_epsilon + ) + return x, residual + out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) + return out diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index cf5045fda5..b81f3d2a04 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -29,7 +29,7 @@ @dataclasses.dataclass -class LogitProcessorOutput: +class LogitsProcessorOutput: # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor # The logprobs of the next tokens. shape: [#seq, vocab_size] @@ -55,6 +55,9 @@ class LogitsMetadata: extend_start_loc: Optional[torch.Tensor] = None top_logprobs_nums: Optional[List[int]] = None + extend_seq_lens_cpu: List[int] = None + logprob_start_lens_cpu: List[int] = None + @classmethod def from_input_metadata(cls, input_metadata: InputMetadata): return cls( @@ -63,22 +66,30 @@ def from_input_metadata(cls, input_metadata: InputMetadata): extend_start_loc=input_metadata.extend_start_loc, return_logprob=input_metadata.return_logprob, top_logprobs_nums=input_metadata.top_logprobs_nums, + extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu, + logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu, ) class LogitsProcessor(nn.Module): - def __init__(self, config): + def __init__(self, config, skip_all_gather: bool = False): super().__init__() self.config = config - self.tp_size = get_tensor_model_parallel_world_size() + self.do_tensor_parallel_all_gather = ( + not skip_all_gather and get_tensor_model_parallel_world_size() > 1 + ) def _get_normalized_prompt_logprobs( - self, input_token_logprobs, logits_metadata: LogitsMetadata + self, + input_token_logprobs: torch.Tensor, + cum_start_len0: torch.Tensor, + cum_start_len1: torch.Tensor, + logits_metadata: LogitsMetadata, ): logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) - start = logits_metadata.extend_start_loc.clone() - end = start + logits_metadata.extend_seq_lens - 2 + start = logits_metadata.extend_start_loc.clone() - cum_start_len0 + end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1 start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) sum_logp = ( @@ -91,7 +102,7 @@ def _get_normalized_prompt_logprobs( return normalized_prompt_logprobs @staticmethod - def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): + def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): if logits_metadata.forward_mode == ForwardMode.DECODE: output_top_logprobs = [] max_k = max(logits_metadata.top_logprobs_nums) @@ -105,7 +116,7 @@ def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): # TODO: vectorize the code below input_top_logprobs, output_top_logprobs = [], [] pt = 0 - extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() + extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu max_k = max(logits_metadata.top_logprobs_nums) ret = all_logprobs.topk(max_k, dim=1) @@ -113,26 +124,30 @@ def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): indices = ret.indices.tolist() for i, extend_seq_len in enumerate(extend_seq_lens_cpu): + start_len = logits_metadata.logprob_start_lens_cpu[i] + pruned_len = extend_seq_len - start_len + if extend_seq_len == 0: 
input_top_logprobs.append([]) output_top_logprobs.append([]) continue + k = logits_metadata.top_logprobs_nums[i] input_top_logprobs.append( [ list(zip(values[pt + j][:k], indices[pt + j][:k])) - for j in range(extend_seq_len - 1) + for j in range(pruned_len - 1) ] ) output_top_logprobs.append( list( zip( - values[pt + extend_seq_len - 1][:k], - indices[pt + extend_seq_len - 1][:k], + values[pt + pruned_len - 1][:k], + indices[pt + pruned_len - 1][:k], ) ) ) - pt += extend_seq_len + pt += pruned_len return input_top_logprobs, output_top_logprobs @@ -159,18 +174,18 @@ def forward( last_hidden = hidden_states[last_index] last_logits = torch.matmul(last_hidden, weight.T) - if self.tp_size > 1: + if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): - last_logits /= self.config.final_logit_softcapping - last_logits = torch.tanh(last_logits) - last_logits *= self.config.final_logit_softcapping + last_logits.div_(self.config.final_logit_softcapping) + torch.tanh(last_logits, out=last_logits) + last_logits.mul_(self.config.final_logit_softcapping) # Return only last_logits if logprob is not requested if not logits_metadata.return_logprob: - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, normalized_prompt_logprobs=None, @@ -194,7 +209,7 @@ def forward( else: output_top_logprobs = None - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=None, @@ -203,15 +218,31 @@ def forward( output_top_logprobs=output_top_logprobs, ) else: - all_logits = torch.matmul(hidden_states, weight.T) - if self.tp_size > 1: + pt, states, pruned_input_ids = 0, [], [] + for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu): + start_len = logits_metadata.logprob_start_lens_cpu[i] + states.append(hidden_states[pt + start_len : pt + extend_len]) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + states = torch.cat(states, dim=0) + pruned_input_ids = torch.cat(pruned_input_ids, dim=0) + + cum_start_len1 = torch.tensor( + logits_metadata.logprob_start_lens_cpu, device="cuda" + ).cumsum(0) + cum_start_len0 = torch.zeros_like(cum_start_len1) + cum_start_len0[1:] = cum_start_len1[:-1] + + all_logits = torch.matmul(states, weight.T) + if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = all_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): - all_logits /= self.config.final_logit_softcapping - all_logits = torch.tanh(all_logits) - all_logits *= self.config.final_logit_softcapping + all_logits.div_(self.config.final_logit_softcapping) + torch.tanh(all_logits, out=all_logits) + all_logits.mul_(self.config.final_logit_softcapping) all_logprobs = all_logits del all_logits, hidden_states @@ -228,20 +259,26 @@ def forward( else: input_top_logprobs = output_top_logprobs = None - last_logprobs = all_logprobs[last_index] + last_logprobs = all_logprobs[last_index - cum_start_len1] # Compute the logprobs and normalized logprobs for the prefill tokens. # Note that we pad a zero at the end of each sequence for easy computation. 
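The in-place ops above implement final-logit softcapping, cap * tanh(logits / cap), which bounds logits to (-cap, cap) while leaving small values nearly unchanged (the Gemma-2 convention). A tiny illustration:

```python
import torch


def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(logits / cap)


x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(softcap(x, 30.0))   # extreme values are squashed into (-30, 30)
```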
input_token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), + torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]), ] normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, logits_metadata + input_token_logprobs, + cum_start_len0, + cum_start_len1, + logits_metadata, ) - return LogitProcessorOutput( + # Remove the last token logprob for the prefill tokens. + input_token_logprobs = input_token_logprobs[:-1] + + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 1568cf6d96..91735a1b81 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -15,6 +15,8 @@ """Radix attention.""" +from typing import Optional + import torch from flashinfer.cascade import merge_state from torch import nn @@ -34,6 +36,7 @@ def __init__( scaling: float, num_kv_heads: int, layer_id: int, + sliding_window_size: Optional[int] = None, logit_cap: int = -1, v_head_dim: int = -1, ): @@ -46,6 +49,7 @@ def __init__( self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id + self.sliding_window_size = sliding_window_size if sliding_window_size else -1 if ( not global_server_args_dict.get("disable_flashinfer", False) @@ -113,14 +117,25 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): return o def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): + # using two wrappers is unnecessary in the current PR, but are prepared for future PRs + prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged + if self.sliding_window_size != -1: + prefill_wrapper_paged = prefill_wrapper_paged[0] + else: + if isinstance(prefill_wrapper_paged, list): + prefill_wrapper_paged = prefill_wrapper_paged[1] + if not input_metadata.flashinfer_use_ragged: - self.store_kv_cache(k, v, input_metadata) + if k is not None: + assert v is not None + self.store_kv_cache(k, v, input_metadata) - o = input_metadata.flashinfer_prefill_wrapper_paged.forward( + o = prefill_wrapper_paged.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), causal=True, sm_scale=self.scaling, + window_left=self.sliding_window_size, logits_soft_cap=self.logit_cap, ) else: @@ -138,14 +153,12 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): if input_metadata.extend_no_prefix: o = o1 else: - o2, s2 = ( - input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - causal=False, - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), + causal=False, + sm_scale=self.scaling, + logits_soft_cap=self.logit_cap, ) o, _ = merge_state(o1, s1, o2, s2) @@ -158,9 +171,18 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): return o.view(-1, self.tp_q_head_num * self.head_dim) def decode_forward_flashinfer(self, q, k, v, input_metadata: 
InputMetadata): - self.store_kv_cache(k, v, input_metadata) + decode_wrapper = input_metadata.flashinfer_decode_wrapper + if self.sliding_window_size != -1: + decode_wrapper = decode_wrapper[0] + else: + if isinstance(decode_wrapper, list): + decode_wrapper = decode_wrapper[1] - o = input_metadata.flashinfer_decode_wrapper.forward( + if k is not None: + assert v is not None + self.store_kv_cache(k, v, input_metadata) + + o = decode_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), sm_scale=self.scaling, @@ -170,8 +192,10 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): return o.view(-1, self.tp_q_head_num * self.head_dim) def forward(self, q, k, v, input_metadata: InputMetadata): - k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) - v = v.view(-1, self.tp_v_head_num, self.v_head_dim) + if k is not None: + assert v is not None + k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) + v = v.view(-1, self.tp_v_head_num, self.v_head_dim) if input_metadata.forward_mode == ForwardMode.EXTEND: return self.extend_forward(q, k, v, input_metadata) @@ -179,7 +203,6 @@ def forward(self, q, k, v, input_metadata: InputMetadata): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): - k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) - v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) - k_cache[input_metadata.out_cache_loc] = cache_k - v_cache[input_metadata.out_cache_loc] = cache_v + input_metadata.token_to_kv_pool.set_kv_buffer( + self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v + ) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py new file mode 100644 index 0000000000..6cb7d0a7c1 --- /dev/null +++ b/python/sglang/srt/layers/sampler.py @@ -0,0 +1,154 @@ +import dataclasses +import logging +from typing import Union + +import torch +from flashinfer.sampling import ( + min_p_sampling_from_probs, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, +) +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput + +# TODO: move this dict to another place +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class SampleOutput: + success: torch.Tensor + probs: torch.Tensor + batch_next_token_ids: torch.Tensor + + +class Sampler(CustomOp): + def __init__(self): + super().__init__() + + def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): + # min-token, presence, frequency + if sampling_info.linear_penalties is not None: + logits += sampling_info.linear_penalties + + # repetition + if sampling_info.scaling_penalties is not None: + logits = torch.where( + logits > 0, + logits / sampling_info.scaling_penalties, + logits * sampling_info.scaling_penalties, + ) + + return logits + + def _get_probs( + self, + logits: torch.Tensor, + sampling_info: SamplingBatchInfo, + is_torch_compile: bool = False, + ): + # Post process logits + logits = logits.contiguous() + logits.div_(sampling_info.temperatures) + if is_torch_compile: + # FIXME: Temporary workaround for unknown bugs in torch.compile + logits.add_(0) + + if sampling_info.logit_bias is not None: + 
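For the sliding-window path in RadixAttention above, window_left limits how far back a query position may attend, and -1 means full causal attention. A torch sketch of the equivalent attention mask (the mapping shown here is an illustrative assumption, not flashinfer's implementation):

```python
import torch


def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    i = torch.arange(seq_len).unsqueeze(1)   # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    causal = j <= i
    if window < 0:                           # -1: no sliding window
        return causal
    return causal & (j >= i - window)        # keep only the last `window` + 1 keys


print(sliding_window_causal_mask(6, 2).int())
```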
logits.add_(sampling_info.logit_bias) + + if sampling_info.vocab_mask is not None: + logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) + + logits = self._apply_penalties(logits, sampling_info) + + return torch.softmax(logits, dim=-1) + + def forward_cuda( + self, + logits: Union[torch.Tensor, LogitsProcessorOutput], + sampling_info: SamplingBatchInfo, + ): + if isinstance(logits, LogitsProcessorOutput): + logits = logits.next_token_logits + + probs = self._get_probs(logits, sampling_info) + + if not global_server_args_dict["disable_flashinfer_sampling"]: + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand( + (max_top_k_round, batch_size), device=probs.device + ) + if sampling_info.need_min_p_sampling: + probs = top_k_renorm_prob(probs, sampling_info.top_ks) + probs = top_p_renorm_prob(probs, sampling_info.top_ps) + batch_next_token_ids, success = min_p_sampling_from_probs( + probs, uniform_samples, sampling_info.min_ps + ) + else: + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps + ) + else: + # Here we provide a slower fallback implementation. + batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch( + probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps + ) + + return SampleOutput(success, probs, batch_next_token_ids) + + def forward_native( + self, + logits: Union[torch.Tensor, LogitsProcessorOutput], + sampling_info: SamplingBatchInfo, + ): + if isinstance(logits, LogitsProcessorOutput): + logits = logits.next_token_logits + + probs = self._get_probs(logits, sampling_info, is_torch_compile=True) + + batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch( + probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps + ) + + return SampleOutput(success, probs, batch_next_token_ids) + + +def top_k_top_p_min_p_sampling_from_probs_torch( + probs: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + min_ps: torch.Tensor, +): + """A top-k, top-p and min-p sampling implementation with native pytorch operations.""" + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + min_p_thresholds = probs_sort[:, 0] * min_ps + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + probs_sort[ + torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) + >= top_ks.view(-1, 1) + ] = 0.0 + probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 + probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) + try: + # FIXME: torch.multiomial does not support num_samples = 1 + sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[ + :, :1 + ] + except RuntimeError as e: + logger.warning(f"Sampling error: {e}") + batch_next_token_ids = torch.zeros( + (probs_sort.shape[0],), dtype=torch.int32, device=probs.device + ) + success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device) + return batch_next_token_ids, success + + batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) + success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device) + return batch_next_token_ids, success diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index 98a1464656..67c4454c43 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -40,6 +40,8 @@ 
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import kill_parent_process from sglang.utils import get_cache_info, get_exception_traceback +from sglang.srt.utils import configure_logger, kill_parent_process +from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -239,10 +241,7 @@ def start_controller_process( ): """Start a controller process.""" - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) + configure_logger(server_args) try: controller = ControllerMultiFlex(server_args, port_args, model_overide_args) @@ -257,6 +256,4 @@ def start_controller_process( except Exception: logger.error("Exception in ControllerMultiFlex:\n" + get_exception_traceback()) finally: - for w in controller.workers: - os.kill(w.proc.pid, 9) kill_parent_process() diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 2ada4fa2aa..4175b6a252 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -17,7 +17,6 @@ import logging import multiprocessing -import os from typing import List import numpy as np @@ -30,7 +29,7 @@ launch_tp_servers, ) from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import kill_parent_process +from sglang.srt.utils import configure_logger, kill_parent_process from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -58,7 +57,7 @@ def __init__( # Need by multi flex infer self.controller_info = controller_info - # Init communication + # Init inter-process communication context = zmq.Context(2) if not self.is_dp_worker: @@ -142,11 +141,11 @@ def start_controller_process( controller_info: ControllerInfo = None, ): """Start a controller process.""" - - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) + if is_data_parallel_worker: + logger_prefix = f" DP{dp_worker_id} TP0" + else: + logger_prefix = " TP0" + configure_logger(server_args, prefix=logger_prefix) if not is_data_parallel_worker: tp_size_local = server_args.tp_size // server_args.nnodes @@ -176,6 +175,4 @@ def start_controller_process( except Exception: logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) finally: - for t in controller.tp_procs: - os.kill(t.pid, 9) kill_parent_process() diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 08ccfd5cef..cd5f63125c 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -17,7 +17,6 @@ import asyncio import dataclasses -import inspect from typing import List import uvloop @@ -29,6 +28,7 @@ BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + UpdateWeightReqOutput, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.server_args import PortArgs, ServerArgs @@ -39,6 +39,8 @@ @dataclasses.dataclass class DecodeStatus: + """Store the status of incremental decoding.""" + vid: int decoded_text: str decode_ids: List[int] @@ -47,11 +49,14 @@ class DecodeStatus: class DetokenizerManager: + """DetokenizerManager is a process that detokenizes the token ids.""" + def __init__( self, server_args: ServerArgs, port_args: PortArgs, ): + # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_router = context.socket(zmq.PULL) 
self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}") @@ -71,10 +76,13 @@ def __init__( self.decode_status = {} async def handle_loop(self): + """The event loop that handles requests""" + while True: - recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() + recv_obj = await self.recv_from_router.recv_pyobj() if isinstance(recv_obj, BatchEmbeddingOut): + # If it is embedding model, no detokenization is needed. self.send_to_tokenizer.send_pyobj( BatchEmbeddingOut( rids=recv_obj.rids, @@ -84,15 +92,18 @@ async def handle_loop(self): ) ) continue + elif isinstance(recv_obj, UpdateWeightReqOutput): + # If it is a weight update request, no detokenization is needed. + self.send_to_tokenizer.send_pyobj(recv_obj) + continue + elif self.tokenizer is None: + # If the tokenizer is skipped, no detokenization is needed + self.send_to_tokenizer.send_pyobj(recv_obj) + continue assert isinstance(recv_obj, BatchTokenIDOut) bs = len(recv_obj.rids) - if self.tokenizer is None: - # Send BatchTokenIDOut if no tokenizer init'ed. - self.send_to_tokenizer.send_pyobj(recv_obj) - continue - # Initialize decode status read_ids, surr_ids = [], [] for i in range(bs): @@ -126,8 +137,7 @@ async def handle_loop(self): spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) - # Trim stop str - # TODO(lmzheng): handle the case where multiple stop strs are hit + # Incremental decoding output_strs = [] for i in range(bs): s = self.decode_status[recv_obj.rids[i]] @@ -144,6 +154,7 @@ async def handle_loop(self): output_strs.append(s.decoded_text + new_text) + # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): pos = output_strs[i].find(recv_obj.finished_reason[i].matched) if pos != -1: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5df635d270..85cbb27761 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -20,6 +20,7 @@ import multiprocessing import uuid +import copy from dataclasses import dataclass from multiprocessing import Value from typing import Dict, List, Optional, Union @@ -32,6 +33,7 @@ from sglang.utils import get_cache_info + @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. @@ -47,9 +49,9 @@ class GenerateReqInput: rid: Optional[Union[List[str], str]] = None # Whether to return logprobs. return_logprob: Optional[Union[List[bool], bool]] = None - # The start location of the prompt for return_logprob. + # If return logprobs, the start location in the prompt for returning logprobs. logprob_start_len: Optional[Union[List[int], int]] = None - # The number of top logprobs to return. + # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: Optional[Union[List[int], int]] = None # Whether to detokenize tokens in text in the returned logprobs. 
return_text_in_logprobs: bool = False @@ -61,6 +63,7 @@ def post_init(self): self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") + if ( isinstance(self.sampling_params, dict) and self.sampling_params.get("n", 1) != 1 @@ -81,7 +84,7 @@ def post_init(self): if self.return_logprob is None: self.return_logprob = False if self.logprob_start_len is None: - self.logprob_start_len = 0 + self.logprob_start_len = -1 if self.top_logprobs_num is None: self.top_logprobs_num = 0 else: @@ -147,7 +150,7 @@ def post_init(self): self.return_logprob = [self.return_logprob] * num if self.logprob_start_len is None: - self.logprob_start_len = [0] * num + self.logprob_start_len = [-1] * num elif not isinstance(self.logprob_start_len, list): self.logprob_start_len = [self.logprob_start_len] * num @@ -159,16 +162,27 @@ def post_init(self): @dataclass class TokenizedGenerateReqInput: + # The request id rid: str + # The input text input_text: str + # The input token ids input_ids: List[int] + # The pixel values for input images pixel_values: List[float] - image_hash: int - image_size: List[int] + # The hash values of input images + image_hashes: List[int] + # The image sizes + image_sizes: List[List[int]] + # The sampling parameters sampling_params: SamplingParams + # Whether to return the logprobs return_logprob: bool + # If return logprobs, the start location in the prompt for returning logprobs. logprob_start_len: int + # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: int + # Whether to stream output stream: bool @@ -219,15 +233,21 @@ def post_init(self): @dataclass class TokenizedEmbeddingReqInput: + # The request id rid: str + # The input text input_text: str + # The input token ids input_ids: List[int] + # Dummy sampling params for compatibility sampling_params: SamplingParams @dataclass class BatchTokenIDOut: + # The request id rids: List[str] + # The version id to sync decode status with in detokenizer_manager vids: List[int] decoded_texts: List[str] decode_ids: List[int] @@ -237,20 +257,32 @@ class BatchTokenIDOut: meta_info: List[Dict] finished_reason: List[BaseFinishReason] + def __post_init__(self): + # deepcopy meta_info to avoid modification in place + self.meta_info = copy.deepcopy(self.meta_info) + @dataclass class BatchStrOut: + # The request id rids: List[str] + # The output decoded strings output_strs: List[str] + # The meta info meta_info: List[Dict] + # The finish reason finished_reason: List[BaseFinishReason] @dataclass class BatchEmbeddingOut: + # The request id rids: List[str] + # The output embedding embeddings: List[List[float]] + # The meta info meta_info: List[Dict] + # The finish reason finished_reason: List[BaseFinishReason] @@ -260,8 +292,11 @@ class FlushCacheReq: @dataclass -class AbortReq: - rid: str +class UpdateWeightReqInput: + # The model path with the new weights + model_path: str + # The format to load the weights + load_format: Optional[str] = None @dataclass @@ -283,3 +318,14 @@ def __init__(self, server_args, model_overide_args): self.running_reqs.append(Value("i", 0)) self.waiting_reqs.append(Value("i", 0)) self.swap_in_queue.append(multiprocessing.Queue()) + +class UpdateWeightReqOutput: + success: bool + message: str + + +@dataclass +class AbortReq: + # The request id + rid: str + diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 9d5f991975..04169e8086 100644 --- 
a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -111,11 +111,14 @@ def __init__( rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], + mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache - self.rem_total_tokens = rem_total_tokens - self.rem_input_tokens = rem_input_tokens + self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens + self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens + if self.rem_chunk_tokens is not None: + self.rem_chunk_tokens -= mixed_with_decode_tokens self.can_run_list = [] self.new_inflight_req = None diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a461fa1812..f5b9c9eb27 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,20 +18,22 @@ """Meta data for requests and batches""" import logging -import warnings from dataclasses import dataclass -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch -from flashinfer.sampling import top_k_top_p_sampling_from_probs -import sglang.srt.sampling.penaltylib as penaltylib from sglang.global_config import global_config from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + +if TYPE_CHECKING: + from sglang.srt.layers.sampler import SampleOutput + INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -37,7 +41,7 @@ global_server_args_dict = { "disable_flashinfer": False, "disable_flashinfer_sampling": False, - "attention_reduce_in_fp32": False, + "triton_attention_reduce_in_fp32": False, "enable_mla": False, } @@ -123,8 +127,8 @@ def __init__(self, rid, origin_input_text, origin_input_ids): # For vision input self.pixel_values = None - self.image_size = None - self.image_offset = None + self.image_sizes = None + self.image_offsets = None self.pad_value = None # Prefix info @@ -235,10 +239,12 @@ def check_finished(self): return last_token_id = self.output_ids[-1] - if self.tokenizer is None: - matched_eos = last_token_id in self.sampling_params.stop_token_ids - else: - matched_eos = last_token_id == self.tokenizer.eos_token_id + + matched_eos = last_token_id in self.sampling_params.stop_token_ids + + if self.tokenizer is not None: + matched_eos |= last_token_id == self.tokenizer.eos_token_id + if matched_eos and not self.sampling_params.ignore_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) return @@ -262,11 +268,18 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): all_text = self.origin_input_text + self.decoded_text + jump_forward_str all_ids = self.tokenizer.encode(all_text) + if not all_ids: + logger.warning("Encoded all_text resulted in empty all_ids") + return False + prompt_tokens = len(self.origin_input_ids_unpadded) + if prompt_tokens > len(all_ids): + logger.warning("prompt_tokens is larger than encoded all_ids") + return False if all_ids[prompt_tokens - 1] != 
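The PrefillAdder change earlier reserves room for the decode tokens that a mixed batch will also run, by subtracting them from every prefill budget up front. With made-up numbers:

```python
rem_total_tokens, rem_input_tokens, rem_chunk_tokens = 4096, 2048, 1024
mixed_with_decode_tokens = 64   # roughly one token per running decode request

rem_total_tokens -= mixed_with_decode_tokens
rem_input_tokens -= mixed_with_decode_tokens
if rem_chunk_tokens is not None:
    rem_chunk_tokens -= mixed_with_decode_tokens
print(rem_total_tokens, rem_input_tokens, rem_chunk_tokens)   # 4032 1984 960
```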
self.origin_input_ids_unpadded[-1]: # TODO(lsyin): fix token fusion - warnings.warn( + logger.warning( "Token fusion between input and output, try to avoid this by removing the space at the end of the input." ) return False @@ -325,17 +338,13 @@ class ScheduleBatch: out_cache_loc: torch.Tensor = None extend_num_tokens: int = None + # For mixed chunekd prefill + prefix_lens_cpu: List[int] = None + # For processing logprobs return_logprob: bool = False top_logprobs_nums: List[int] = None - # Batched sampling params - temperatures: torch.Tensor = None - top_ps: torch.Tensor = None - top_ks: torch.Tensor = None - penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None - logit_bias: torch.Tensor = None - @classmethod def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): return_logprob = any(req.return_logprob for req in reqs) @@ -383,51 +392,7 @@ def alloc_token_slots(self, num_tokens: int): return out_cache_loc - def batch_sampling_params(self, vocab_size, int_token_logit_bias): - device = "cuda" - bs, reqs = self.batch_size(), self.reqs - self.temperatures = torch.tensor( - [r.sampling_params.temperature for r in reqs], - dtype=torch.float, - device=device, - ).view(-1, 1) - self.top_ps = torch.tensor( - [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device - ) - self.top_ks = torch.tensor( - [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device - ) - - # Each penalizers will do nothing if they evaluate themselves as not required by looking at - # the sampling_params of the requests (See {_is_required()} of each penalizers). So this - # should not add hefty computation overhead other than simple checks. - # - # While we choose not to even create the class instances if they are not required, this - # could add additional complexity to the {ScheduleBatch} class, especially we need to - # handle {filter_batch()} and {merge()} cases as well. 
- self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( - vocab_size=vocab_size, - batch=self, - device=device, - Penalizers={ - penaltylib.BatchedFrequencyPenalizer, - penaltylib.BatchedMinNewTokensPenalizer, - penaltylib.BatchedPresencePenalizer, - penaltylib.BatchedRepetitionPenalizer, - }, - ) - - # Handle logit bias but only allocate when needed - self.logit_bias = None - for i in range(bs): - if reqs[i].sampling_params.dtype == "int": - if self.logit_bias is None: - self.logit_bias = torch.zeros( - (bs, vocab_size), dtype=torch.float32, device=device - ) - self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias - - def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): + def prepare_for_extend(self, vocab_size: int): bs = self.batch_size() reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] @@ -465,8 +430,32 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs] - self.batch_sampling_params(vocab_size, int_token_logit_bias) + self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) + + def mix_with_running(self, running_batch: "ScheduleBatch"): + # NOTE: prefix_indices is what has been cached, but we don't cache each decode step + prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs] + prefix_lens_cpu.extend( + [ + len(r.origin_input_ids) + len(r.output_ids) - 1 + for r in running_batch.reqs + ] + ) + + for req in running_batch.reqs: + req.fill_ids = req.origin_input_ids + req.output_ids + req.extend_input_len = 1 + + input_ids = torch.cat([self.input_ids, running_batch.input_ids]) + out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) + extend_num_tokens = self.extend_num_tokens + running_batch.batch_size() + self.merge(running_batch) + self.input_ids = input_ids + self.out_cache_loc = out_cache_loc + self.extend_num_tokens = extend_num_tokens + self.prefix_lens_cpu = prefix_lens_cpu def check_decode_mem(self): bs = self.batch_size() @@ -617,12 +606,12 @@ def check_for_jump_forward(self, model_runner): if req.pixel_values is not None: ( req.origin_input_ids, - req.image_offset, + req.image_offsets, ) = model_runner.model.pad_input_ids( req.origin_input_ids_unpadded, req.pad_value, - req.pixel_values.shape, - req.image_size, + req.pixel_values, + req.image_sizes, ) jump_forward_reqs.append(req) @@ -639,7 +628,7 @@ def prepare_for_decode(self, input_ids=None): for r in self.reqs ] else: - self.penalizer_orchestrator.cumulate_input_tokens(input_ids) + self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids) self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") self.seq_lens.add_(1) @@ -652,6 +641,8 @@ def prepare_for_decode(self, input_ids=None): self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc + self.sampling_info.update_regex_vocab_mask(self) + def filter_batch(self, unfinished_indices: List[int]): if unfinished_indices is None or len(unfinished_indices) == 0: # Filter out all requests @@ -672,23 +663,13 @@ def filter_batch(self, unfinished_indices: List[int]): self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) - self.penalizer_orchestrator.filter(unfinished_indices, 
new_indices) - - for item in [ - "temperatures", - "top_ps", - "top_ks", - "logit_bias", - ]: - self_val = getattr(self, item, None) - if self_val is not None: # logit_bias can be None - setattr(self, item, self_val[new_indices]) + self.sampling_info.filter(unfinished_indices, new_indices) def merge(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. - self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + self.sampling_info.merge(other.sampling_info) self.reqs.extend(other.reqs) @@ -703,111 +684,17 @@ def merge(self, other: "ScheduleBatch"): self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) - for item in [ - "temperatures", - "top_ps", - "top_ks", - ]: - self_val = getattr(self, item, None) - other_val = getattr(other, item, None) - setattr(self, item, torch.concat([self_val, other_val])) - - # logit_bias can be None - if self.logit_bias is not None or other.logit_bias is not None: - vocab_size = ( - self.logit_bias.shape[1] - if self.logit_bias is not None - else other.logit_bias.shape[1] - ) - if self.logit_bias is None: - self.logit_bias = torch.zeros( - (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - if other.logit_bias is None: - other.logit_bias = torch.zeros( - (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) - - def sample(self, logits: torch.Tensor): - # TODO(lsyin): move this into a part of layer and run with CUDA Graph - # Post process logits - logits = logits.contiguous() - logits.div_(self.temperatures) - if self.logit_bias is not None: - logits.add_(self.logit_bias) - - has_regex = any(req.regex_fsm is not None for req in self.reqs) - if has_regex: - allowed_mask = torch.empty_like(logits[0], dtype=torch.bool) - for i, req in enumerate(self.reqs): - if req.regex_fsm is not None: - allowed_mask.zero_() - allowed_mask[ - req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens - ] = 1 - logits[i].masked_fill_(~allowed_mask, float("-inf")) - - logits = self.penalizer_orchestrator.apply(logits) - - probs = torch.softmax(logits, dim=-1) - - if not global_server_args_dict["disable_flashinfer_sampling"]: - max_top_k_round, batch_size = 32, probs.shape[0] - uniform_samples = torch.rand( - (max_top_k_round, batch_size), device=probs.device - ) - batch_next_token_ids, success = top_k_top_p_sampling_from_probs( - probs, uniform_samples, self.top_ks, self.top_ps - ) - else: - # Here we provide a slower fallback implementation. 
- batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch( - probs, self.top_ks, self.top_ps - ) - - if not torch.all(success): - warnings.warn("Sampling failed, fallback to top_k=1 strategy") + def check_sample_results(self, sample_output: SampleOutput): + if not torch.all(sample_output.success): + probs = sample_output.probs + batch_next_token_ids = sample_output.batch_next_token_ids + logging.warning("Sampling failed, fallback to top_k=1 strategy") probs = probs.masked_fill(torch.isnan(probs), 0.0) argmax_ids = torch.argmax(probs, dim=-1) batch_next_token_ids = torch.where( - success, batch_next_token_ids, argmax_ids + sample_output.success, batch_next_token_ids, argmax_ids ) + sample_output.probs = probs + sample_output.batch_next_token_ids = batch_next_token_ids - if has_regex: - batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy() - for i, req in enumerate(self.reqs): - if req.regex_fsm is not None: - req.regex_fsm_state = req.regex_fsm.get_next_state( - req.regex_fsm_state, batch_next_token_ids_cpu[i] - ) - - self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids) - - return batch_next_token_ids - - -def top_k_top_p_sampling_from_probs_torch( - probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor -): - """A top-k and top-k sampling implementation with native pytorch operations.""" - probs_sort, probs_idx = probs.sort(dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 - probs_sort[ - torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) - >= top_ks.view(-1, 1) - ] = 0.0 - probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) - try: - sampled_index = torch.multinomial(probs_sort, num_samples=1) - except RuntimeError: - batch_next_token_ids = torch.zeros( - (probs_sort.shape[0],), dtype=torch.int32, device=probs.device - ) - success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device) - return batch_next_token_ids, success - - batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) - success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device) - return batch_next_token_ids, success + return sample_output.batch_next_token_ids diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e1bfbc7e67..5ad4152ea9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -21,8 +21,9 @@ import logging import multiprocessing as mp import os -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union +import fastapi import numpy as np import transformers import uvloop @@ -46,9 +47,11 @@ GenerateReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UpdateWeightReqInput, + UpdateWeightReqOutput, ) from sglang.srt.mm_utils import expand2square, process_anyres_image -from sglang.srt.sampling_params import SamplingParams +from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image from sglang.utils import get_exception_traceback @@ -60,12 +63,16 @@ @dataclasses.dataclass class ReqState: + """Store the state a request.""" + out_list: List finished: bool event: asyncio.Event class TokenizerManager: + """TokenizerManager is a process that tokenizes the text.""" + def __init__( self, 
server_args: ServerArgs, @@ -74,6 +81,7 @@ def __init__( ): self.server_args = server_args + # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") @@ -81,6 +89,7 @@ def __init__( self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}") + # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name self.hf_config = get_config( @@ -88,17 +97,18 @@ def __init__( trust_remote_code=server_args.trust_remote_code, model_overide_args=model_overide_args, ) - self.is_generation = is_generation_model(self.hf_config.architectures) - - if server_args.context_length is not None: - self.context_len = server_args.context_length - else: - self.context_len = get_context_length(self.hf_config) + self.is_generation = is_generation_model( + self.hf_config.architectures, self.server_args.is_embedding + ) + self.context_len = server_args.context_length or get_context_length( + self.hf_config + ) + # Create tokenizer if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(self.model_path): + if is_multimodal_model(self.hf_config.architectures): self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -106,6 +116,9 @@ def __init__( ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # We want to parallelize the image pre-processing so we + # create an executor for it self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), @@ -118,34 +131,25 @@ def __init__( trust_remote_code=server_args.trust_remote_code, ) + # Store states self.to_create_loop = True self.rid_to_state: Dict[str, ReqState] = {} - async def get_pixel_values(self, image_data): - aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) - grid_pinpoints = ( - self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None - ) - if self.executor is not None: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - get_pixel_values, - image_data, - aspect_ratio, - grid_pinpoints, - ) - else: - return get_pixel_values( - image_data, aspect_ratio, grid_pinpoints, self.processor - ) + # For update model weights + self.model_update_lock = asyncio.Lock() + self.model_update_result = None async def generate_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, ): if self.to_create_loop: self.create_handle_loop() + while self.model_update_lock.locked(): + await asyncio.sleep(0.001) + obj.post_init() is_single = obj.is_single @@ -153,18 +157,15 @@ async def generate_request( async for response in self._handle_single_request(obj, request): yield response else: - if hasattr(obj, "stream") and obj.stream: - raise ValueError("Do not support stream for batch mode.") - async for response in self._handle_batch_request(obj, request): yield response async def _handle_single_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], - request, - index=None, - is_cache_for_prefill=False, + request: Optional[fastapi.Request] = None, + index: Optional[int] = None, + is_cache_for_prefill: Optional[bool] = False, ): if not 
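generate_request above spin-waits on model_update_lock so that no new request is tokenized and dispatched while weights are being reloaded. A minimal asyncio sketch of that gating, with illustrative names and a fake reload:

```python
import asyncio

model_update_lock = asyncio.Lock()


async def update_weights():
    async with model_update_lock:
        await asyncio.sleep(0.1)        # stand-in for reloading model weights
        print("weights updated")


async def generate(i: int):
    while model_update_lock.locked():   # same check as generate_request above
        await asyncio.sleep(0.001)
    print(f"request {i} dispatched")


async def main():
    await asyncio.gather(update_weights(), *(generate(i) for i in range(3)))


asyncio.run(main())
```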
is_cache_for_prefill: # The normal case with a single prompt not_use_index = index is None @@ -184,7 +185,7 @@ async def _handle_single_request( ) if self.is_generation: - pixel_values, image_hash, image_size = await self._get_pixel_values( + pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data if not_use_index else obj.image_data[index] ) return_logprob = ( @@ -195,6 +196,8 @@ async def _handle_single_request( if not_use_index else obj.logprob_start_len[index] ) + if return_logprob and logprob_start_len == -1: + logprob_start_len = len(input_ids) - 1 top_logprobs_num = ( obj.top_logprobs_num if not_use_index @@ -237,21 +240,24 @@ async def _handle_single_request( sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params.max_new_tokens = 0 - pixel_values, image_hash, image_size = await self._get_pixel_values( + pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data[0] ) return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] + # Send to the controller if self.is_generation: + if return_logprob and logprob_start_len == -1: + logprob_start_len = len(input_ids) - 1 tokenized_obj = TokenizedGenerateReqInput( rid, input_text, input_ids, pixel_values, - image_hash, - image_size, + image_hashes, + image_sizes, sampling_params, return_logprob, logprob_start_len, @@ -265,31 +271,31 @@ async def _handle_single_request( input_ids, sampling_params, ) - self.send_to_router.send_pyobj(tokenized_obj) + # Recv results event = asyncio.Event() state = ReqState([], False, event) self.rid_to_state[rid] = state if not is_cache_for_prefill: - async for response in self._wait_for_response( - event, state, obj, rid, request - ): + async for response in self._wait_for_response(state, obj, rid, request): yield response else: assert self.is_generation - await self._wait_for_cache_prefill_response(event, state, obj, rid, request) + await self._wait_for_cache_prefill_response(state, obj, rid, request) yield input_ids async def _handle_batch_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], request + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, ): batch_size = obj.batch_size if self.is_generation: parallel_sample_num = obj.parallel_sample_num if parallel_sample_num != 1: - # Send prefill requests to cache the common input + # Send prefill requests to cache the common prefix parallel_sample_num += 1 input_id_result = [] if obj.input_ids is None else None for i in range(batch_size): @@ -306,6 +312,7 @@ async def _handle_batch_request( parallel_sample_num = 1 # First send out all requests + generators = [] for i in range(batch_size): for j in range(parallel_sample_num): if j == 0 and parallel_sample_num != 1: @@ -334,8 +341,10 @@ async def _handle_batch_request( sampling_params = self._get_sampling_params(obj.sampling_params[index]) if self.is_generation: - pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data[index] + if obj.return_logprob[index] and obj.logprob_start_len[index] == -1: + obj.logprob_start_len[index] = len(input_ids) - 1 + pixel_values, image_hashes, image_sizes = ( + await self._get_pixel_values(obj.image_data[index]) ) tokenized_obj = TokenizedGenerateReqInput( @@ -343,8 +352,8 @@ async def _handle_batch_request( input_text, input_ids, pixel_values, - image_hash, - image_size, + image_hashes, + image_sizes, sampling_params, 
obj.return_logprob[index], obj.logprob_start_len[index], @@ -364,42 +373,47 @@ async def _handle_batch_request( state = ReqState([], False, event) self.rid_to_state[rid] = state - # Then wait for all responses - output_list = [] - for i in range(batch_size): - for j in range(parallel_sample_num): - if j == 0 and parallel_sample_num != 1: - continue - index = i * parallel_sample_num + j - if parallel_sample_num != 1: - index += batch_size - 1 - i - rid = obj.rid[index] - state = self.rid_to_state[rid] - - while True: - try: - await asyncio.wait_for(state.event.wait(), timeout=4) - break - except asyncio.TimeoutError: - if request is not None and await request.is_disconnected(): - for rid in obj.rid: - self.abort_request(rid) - raise ValueError(f"Abort request {rid}") - continue - if self.is_generation: - output_list.append( - self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob[index], - obj.top_logprobs_num[index], - obj.return_text_in_logprobs, - ) + generators.append( + self._wait_for_response( + state, + obj, + rid, + request, + index=index, + response_index=len(generators), ) - else: - output_list.append(state.out_list[-1]) - assert state.finished - del self.rid_to_state[rid] - yield output_list + ) + + # Then process the responses based on streaming option + is_stream = hasattr(obj, "stream") and obj.stream + + tasks = [asyncio.create_task(gen.__anext__()) for gen in generators] + output_list = [None] * len(tasks) + + # Recv results + while tasks: + done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + cur_index = tasks.index(task) + + try: + result = task.result() + + if is_stream: + yield result + else: + output_list[result["index"]] = result + + tasks[cur_index] = asyncio.create_task( + generators[cur_index].__anext__() + ) + except StopAsyncIteration: + del generators[cur_index] + del tasks[cur_index] + + if not is_stream: + yield output_list def _validate_input_length(self, input_ids: List[int]): if len(input_ids) >= self.context_len: @@ -415,48 +429,44 @@ def _get_sampling_params(self, sampling_params_data: dict): sampling_params.verify() return sampling_params - async def _get_pixel_values(self, image_data): - if isinstance(image_data, list) and len(image_data) > 0: - return await self.get_pixel_values(image_data[0]) - elif isinstance(image_data, str): - return await self.get_pixel_values(image_data) - else: - return None, None, None - async def _wait_for_response( self, - event: asyncio.Event, state: ReqState, obj: Union[GenerateReqInput, EmbeddingReqInput], rid: str, - request, + request: Optional[fastapi.Request] = None, + index: Optional[int] = None, + response_index: int = 0, ): while True: try: - await asyncio.wait_for(event.wait(), timeout=4) + await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): - self.abort_request(rid) + for rid in [obj.rid] if obj.is_single else obj.rid: + self.abort_request(rid) raise ValueError(f"Abort request {rid}") continue if self.is_generation: out = self.convert_logprob_style( state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, + obj.return_logprob if index is None else obj.return_logprob[index], + ( + obj.top_logprobs_num + if index is None + else obj.top_logprobs_num[index] + ), obj.return_text_in_logprobs, ) else: # isinstance(obj, EmbeddingReqInput) out = state.out_list[-1] + out["index"] = response_index + # Log requests if self.server_args.log_requests and state.finished: - 
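The rewritten batch handler above drives one _wait_for_response generator per request and multiplexes them with asyncio.wait(..., return_when=FIRST_COMPLETED), so streamed chunks are yielded as soon as any request produces one while non-streaming callers still get results grouped by index. A self-contained sketch of that fan-in pattern; the generators and payloads here are made up for illustration:

import asyncio


async def fake_stream(name: str, n_chunks: int, delay: float):
    # Stand-in for _wait_for_response: yields a few chunks, then stops.
    for i in range(n_chunks):
        await asyncio.sleep(delay)
        yield {"req": name, "chunk": i}


async def fan_in(generators):
    """Yield items from several async generators as soon as each one is ready."""
    tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            idx = tasks.index(task)
            try:
                item = task.result()
            except StopAsyncIteration:
                # This generator is exhausted; drop it and its task slot.
                del generators[idx]
                del tasks[idx]
                continue
            yield item
            # Immediately schedule the next item from the same generator.
            tasks[idx] = asyncio.create_task(generators[idx].__anext__())


async def main():
    gens = [fake_stream("a", 3, 0.05), fake_stream("b", 2, 0.08)]
    async for item in fan_in(gens):
        print(item)


asyncio.run(main())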
if obj.text is None: - in_obj = {"input_ids": obj.input_ids} - else: - in_obj = {"text": obj.text} - logger.info(f"in={in_obj}, out={out}") + logger.info(f"in={obj}, out={out}") state.out_list = [] if state.finished: @@ -464,16 +474,15 @@ async def _wait_for_response( yield out break - event.clear() + state.event.clear() yield out async def _wait_for_cache_prefill_response( self, - event: asyncio.Event, state: ReqState, obj: GenerateReqInput, rid: str, - request, + request: Optional[fastapi.Request] = None, ): while True: try: @@ -500,6 +509,32 @@ def abort_request(self, rid: str): req = AbortReq(rid) self.send_to_router.send_pyobj(req) + async def update_weights( + self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None + ): + if self.to_create_loop: + self.create_handle_loop() + + # default the load format to the server_args + if obj.load_format is None: + obj.load_format = self.server_args.load_format + + if not self.model_update_lock.locked(): + async with self.model_update_lock: + # wait for the previous generation requests to finish + while len(self.rid_to_state) > 0: + await asyncio.sleep(0) + self.send_to_router.send_pyobj(obj) + self.model_update_result = asyncio.Future() + result = await self.model_update_result + if result.success: + self.server_args.model_path = obj.model_path + self.server_args.load_format = obj.load_format + self.model_path = obj.model_path + return result.success, result.message + else: + return False, "Another update is in progress. Please try again later." + def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. async def abort_request(): @@ -507,7 +542,7 @@ async def abort_request(): if obj.is_single: self.abort_request(obj.rid) else: - for rid in obj.rids: + for rid in obj.rid: self.abort_request(rid) background_tasks = BackgroundTasks() @@ -515,18 +550,29 @@ async def abort_request(): return background_tasks def create_handle_loop(self): + if not self.to_create_loop: + return + self.to_create_loop = False loop = asyncio.get_event_loop() loop.create_task(self.handle_loop()) async def handle_loop(self): + """The event loop that handles requests""" + while True: - recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = ( - await self.recv_from_detokenizer.recv_pyobj() - ) + recv_obj: Union[ + BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput + ] = await self.recv_from_detokenizer.recv_pyobj() + + if isinstance(recv_obj, UpdateWeightReqOutput): + self.model_update_result.set_result(recv_obj) + continue + assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) ), f"Unexpected obj received: {type(recv_obj)}" + for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: @@ -610,11 +656,75 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): ) return top_logprobs + async def _get_pixel_values(self, image_data: List[Union[str, bytes]]): + if not image_data: + return None, None, None + + aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) + grid_pinpoints = ( + self.hf_config.image_grid_pinpoints + if hasattr(self.hf_config, "image_grid_pinpoints") + and "anyres" in aspect_ratio + else None + ) + + if isinstance(image_data, list) and len(image_data) > 0: + # Multiple images + if len(image_data) > 1: + aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. 
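The update_weights coordination added above combines an asyncio.Lock that rejects concurrent updates, a drain loop over in-flight requests, and an asyncio.Future that the receive loop resolves once the worker reports back. A stripped-down sketch of that handshake, with a fake worker task standing in for the router round-trip:

import asyncio


class WeightUpdater:
    def __init__(self):
        self.update_lock = asyncio.Lock()
        self.inflight_requests = set()
        self.pending_result = None

    async def update_weights(self, model_path: str):
        if self.update_lock.locked():
            return False, "Another update is in progress."
        async with self.update_lock:
            # Wait for in-flight generation requests to drain first.
            while self.inflight_requests:
                await asyncio.sleep(0)
            self.pending_result = asyncio.get_running_loop().create_future()
            worker_task = asyncio.create_task(self._fake_worker(model_path))
            success, message = await self.pending_result
            await worker_task
            return success, message

    async def _fake_worker(self, model_path: str):
        # Stand-in for the scheduler/worker reloading weights and replying.
        await asyncio.sleep(0.05)
        self.pending_result.set_result((True, f"loaded {model_path}"))


async def main():
    updater = WeightUpdater()
    print(await updater.update_weights("some/model"))


asyncio.run(main())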
We do not use anyres + pixel_values, image_hashes, image_sizes = [], [], [] + for img_data in image_data: + pixel_v, image_h, image_s = await self._process_single_image( + img_data, aspect_ratio, grid_pinpoints + ) + pixel_values.append(pixel_v) + image_hashes.append(image_h) + image_sizes.append(image_s) + + if isinstance(pixel_values[0], np.ndarray): + pixel_values = np.stack(pixel_values, axis=0) + else: + # A single image + pixel_values, image_hash, image_size = await self._process_single_image( + image_data[0], aspect_ratio, grid_pinpoints + ) + image_hashes = [image_hash] + image_sizes = [image_size] + elif isinstance(image_data, str): + # A single image + pixel_values, image_hash, image_size = await self._process_single_image( + image_data, aspect_ratio, grid_pinpoints + ) + image_hashes = [image_hash] + image_sizes = [image_size] + else: + raise ValueError(f"Invalid image data: {image_data}") + + return pixel_values, image_hashes, image_sizes + + async def _process_single_image( + self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str + ): + if self.executor is not None: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, + _process_single_image_task, + image_data, + aspect_ratio, + grid_pinpoints, + ) + else: + return _process_single_image_task( + image_data, aspect_ratio, grid_pinpoints, self.processor + ) + global global_processor def init_global_processor(server_args: ServerArgs): + """Init the global processor for multi modal models.""" global global_processor transformers.logging.set_verbosity_error() global_processor = get_processor( @@ -624,13 +734,17 @@ def init_global_processor(server_args: ServerArgs): ) -def get_pixel_values( - image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None +def _process_single_image_task( + image_data: Union[str, bytes], + image_aspect_ratio: Optional[str] = None, + image_grid_pinpoints: Optional[str] = None, + processor=None, ): try: processor = processor or global_processor image, image_size = load_image(image_data) if image_size is not None: + # It is a video with multiple images image_hash = hash(image_data) pixel_values = processor.image_processor(image)["pixel_values"] for _ in range(len(pixel_values)): @@ -638,20 +752,28 @@ def get_pixel_values( pixel_values = np.stack(pixel_values, axis=0) return pixel_values, image_hash, image_size else: + # It is an image image_hash = hash(image_data) if image_aspect_ratio == "pad": image = expand2square( image, tuple(int(x * 255) for x in processor.image_processor.image_mean), ) - pixel_values = processor.image_processor(image)["pixel_values"][0] - elif image_aspect_ratio == "anyres": + pixel_values = processor.image_processor(image.convert("RGB"))[ + "pixel_values" + ][0] + elif image_aspect_ratio == "anyres" or ( + image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio + ): pixel_values = process_anyres_image( image, processor.image_processor, image_grid_pinpoints ) else: pixel_values = processor.image_processor(image)["pixel_values"][0] - pixel_values = pixel_values.astype(np.float16) + + if isinstance(pixel_values, np.ndarray): + pixel_values = pixel_values.astype(np.float16) + return pixel_values, image_hash, image.size except Exception: - print("Exception in TokenizerManager:\n" + get_exception_traceback()) + logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 
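As the comment in the tokenizer-manager hunk notes, image pre-processing is pushed into a ProcessPoolExecutor through loop.run_in_executor so the CPU-bound work does not block the event loop. A minimal version of that pattern; decode_and_resize is a placeholder work function, not SGLang's actual image processor:

import asyncio
import concurrent.futures
import hashlib


def decode_and_resize(image_bytes: bytes) -> dict:
    # Placeholder for the real CPU-heavy processor.image_processor(...) call.
    digest = hashlib.sha256(image_bytes).hexdigest()
    return {"hash": digest, "num_bytes": len(image_bytes)}


class ImagePreprocessor:
    def __init__(self, workers: int = 2):
        self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=workers)

    async def process(self, image_bytes: bytes) -> dict:
        loop = asyncio.get_running_loop()
        # Off-load the blocking call so the event loop keeps serving requests.
        return await loop.run_in_executor(self.executor, decode_and_resize, image_bytes)


async def main():
    pre = ImagePreprocessor()
    results = await asyncio.gather(
        pre.process(b"fake-image-1"), pre.process(b"fake-image-2")
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())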
d273e1f668..c9f1868229 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -32,7 +32,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -41,6 +41,8 @@ FlushCacheReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UpdateWeightReqInput, + UpdateWeightReqOutput, ) from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder from sglang.srt.managers.schedule_batch import ( @@ -56,7 +58,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - get_int_token_logit_bias, + configure_logger, is_multimodal_model, set_random_seed, suppress_other_loggers, @@ -66,8 +68,7 @@ logger = logging.getLogger(__name__) -# TODO: Rename "CI" to "SGLANG_IS_IN_CI". -crash_on_warning = os.getenv("CI", "false") == "true" +crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" class ModelTpServer: @@ -94,10 +95,6 @@ def __init__( self.schedule_policy = server_args.schedule_policy self.disable_regex_jump_forward = server_args.disable_regex_jump_forward - # Chunked prefill - self.chunked_prefill_size = server_args.chunked_prefill_size - self.current_inflight_req = None - # Init model and tokenizer self.model_config = ModelConfig( server_args.model_path, @@ -105,6 +102,7 @@ def __init__( context_length=server_args.context_length, model_overide_args=model_overide_args, ) + self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -127,7 +125,7 @@ def __init__( if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(server_args.model_path): + if is_multimodal_model(self.model_config.hf_config.architectures): self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -150,18 +148,21 @@ def __init__( ), self.model_runner.req_to_token_pool.size - 1, ) - self.int_token_logit_bias = torch.tensor( - get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) - ) self.max_req_input_len = min( self.model_config.context_len - 1, self.max_total_num_tokens - 1, ) + + # Sync random seed + server_args.random_seed = broadcast_recv_input( + [server_args.random_seed], + self.tp_rank, + self.model_runner.tp_group.cpu_group, + )[0] set_random_seed(server_args.random_seed) # Print info logger.info( - f"[gpu={self.gpu_id}] " f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_running_requests={self.max_running_requests}, " @@ -198,6 +199,13 @@ def __init__( self.num_generated_tokens = 0 self.last_stats_tic = time.time() + # Chunked prefill + self.chunked_prefill_size = server_args.chunked_prefill_size + self.current_inflight_req = None + self.is_mixed_chunk = ( + self.chunked_prefill_size is not None and server_args.enable_mixed_chunk + ) + # Init the FSM cache for constrained generation if not server_args.skip_tokenizer_init: self.regex_fsm_cache = FSMCache( @@ -207,6 +215,16 @@ def __init__( "trust_remote_code": server_args.trust_remote_code, }, 
skip_tokenizer_init=server_args.skip_tokenizer_init, + json_schema_mode=False, + ) + self.json_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + json_schema_mode=True, ) self.jump_forward_cache = JumpForwardCache() @@ -234,6 +252,9 @@ def exposed_step(self, recv_reqs: List): self.flush_cache() elif isinstance(recv_req, AbortReq): self.abort_request(recv_req) + elif isinstance(recv_req, UpdateWeightReqInput): + success, message = self.update_weights(recv_req) + self.out_pyobjs.append(UpdateWeightReqOutput(success, message)) else: raise ValueError(f"Invalid request: {recv_req}") @@ -291,7 +312,7 @@ def print_decode_stats(self): self.num_generated_tokens = 0 self.last_stats_tic = time.time() logger.info( - f"[gpu={self.gpu_id}] Decode batch. " + f"Decode batch. " f"#running-req: {len(self.running_batch.reqs)}, " f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " @@ -330,29 +351,42 @@ def handle_generate_request( if self.model_runner.is_generation: req.pixel_values = recv_req.pixel_values if req.pixel_values is not None: + # Use image hash as fake token_ids, which is then used + # for prefix matching + image_hash = hash(tuple(recv_req.image_hashes)) req.pad_value = [ - (recv_req.image_hash) % self.model_config.vocab_size, - (recv_req.image_hash >> 16) % self.model_config.vocab_size, - (recv_req.image_hash >> 32) % self.model_config.vocab_size, - (recv_req.image_hash >> 64) % self.model_config.vocab_size, + (image_hash) % self.model_config.vocab_size, + (image_hash >> 16) % self.model_config.vocab_size, + (image_hash >> 32) % self.model_config.vocab_size, + (image_hash >> 64) % self.model_config.vocab_size, ] - req.image_size = recv_req.image_size + req.image_sizes = recv_req.image_sizes ( req.origin_input_ids, - req.image_offset, + req.image_offsets, ) = self.model_runner.model.pad_input_ids( req.origin_input_ids_unpadded, req.pad_value, - req.pixel_values.shape, - req.image_size, + req.pixel_values, + req.image_sizes, ) req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream + # Init regex fsm fron json + if req.sampling_params.json_schema is not None: + req.regex_fsm, computed_regex_string = self.json_fsm_cache.query( + req.sampling_params.json_schema + ) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + computed_regex_string + ) + # Init regex fsm - if req.sampling_params.regex is not None: + elif req.sampling_params.regex is not None: req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) if not self.disable_regex_jump_forward: req.jump_forward_map = self.jump_forward_cache.query( @@ -388,11 +422,14 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: # Get priority queue prefix_computed = self.scheduler.calc_priority(self.waiting_queue) + num_mixed_running = running_bs if self.is_mixed_chunk else 0 + adder = PrefillAdder( self.tree_cache, self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, + num_mixed_running, ) if self.running_batch is not None: @@ -428,26 +465,37 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: # Print stats if self.tp_rank == 0: - self.tree_cache_metrics["total"] += ( - adder.log_input_tokens 
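The scheduler above builds a regex or JSON-schema FSM per request via FSMCache and later advances req.regex_fsm_state with every sampled token, so only tokens with a legal transition can be emitted. A toy illustration of the idea with a hand-written three-state FSM and a fake model, not the outlines/interegular machinery SGLang actually uses:

import torch

# Toy vocabulary and a hand-written FSM accepting the regex a b* c
vocab = ["a", "b", "c", "<eos>"]
# transitions[state][token_id] -> next state; missing entries are illegal tokens
transitions = {
    0: {0: 1},          # start: only "a"
    1: {1: 1, 2: 2},    # after "a": any number of "b", then "c"
    2: {3: 3},          # after "c": only <eos>
}


def constrained_greedy_decode(logits_fn, max_len: int = 8):
    state, out = 0, []
    for _ in range(max_len):
        logits = logits_fn(out)                        # [vocab] scores from the model
        mask = torch.full_like(logits, float("-inf"))
        allowed = list(transitions.get(state, {}).keys())
        mask[allowed] = 0.0                            # keep only FSM-legal tokens
        token_id = int(torch.argmax(logits + mask))
        out.append(token_id)
        state = transitions[state][token_id]           # advance the FSM state
        if vocab[token_id] == "<eos>":
            break
    return "".join(vocab[t] for t in out if vocab[t] != "<eos>")


# A fake model that prefers "b" early and "c" later; the FSM keeps it legal.
fake_logits = lambda prefix: (
    torch.tensor([0.1, 0.9, 0.3, 0.2]) if len(prefix) < 3
    else torch.tensor([0.1, 0.2, 0.9, 0.8])
)
print(constrained_greedy_decode(fake_logits))   # prints "abbc"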
+ adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 - - try: + if isinstance(self.tree_cache, RadixCache): + self.tree_cache_metrics["total"] += ( + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) - except ZeroDivisionError: - tree_cache_hit_rate = 1.0 - logger.info( - f"[gpu={self.gpu_id}] Prefill batch. " - f"#new-seq: {len(can_run_list)}, " - f"#new-token: {adder.log_input_tokens}, " - f"#cached-token: {adder.log_hit_tokens}, " - f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " - f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" - ) + else: + tree_cache_hit_rate = 0.0 + + if num_mixed_running > 0: + logger.info( + f"Prefill batch" + f"(mixed #running-req: {num_mixed_running}). " + f"#new-seq: {len(can_run_list)}, " + f"#new-token: {adder.log_input_tokens}, " + f"#cached-token: {adder.log_hit_tokens}, " + f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " + f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" + ) + else: + logger.info( + f"Prefill batch. " + f"#new-seq: {len(can_run_list)}, " + f"#new-token: {adder.log_input_tokens}, " + f"#cached-token: {adder.log_hit_tokens}, " + f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " + f"#running-req: {running_bs}, " + f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" + ) # Return the new batch new_batch = ScheduleBatch.init_new( @@ -461,9 +509,14 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: def forward_prefill_batch(self, batch: ScheduleBatch): # Build batch tensors - batch.prepare_for_extend( - self.model_config.vocab_size, self.int_token_logit_bias - ) + batch.prepare_for_extend(self.model_config.vocab_size) + + decoding_reqs = [] + if self.is_mixed_chunk and self.running_batch is not None: + self.running_batch.prepare_for_decode() + batch.mix_with_running(self.running_batch) + decoding_reqs = self.running_batch.reqs + self.running_batch = None if self.controller_info: num = 0 @@ -486,18 +539,29 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) + sample_output, logits_output = self.model_runner.forward( + batch, ForwardMode.EXTEND + ) + next_token_ids = batch.check_sample_results(sample_output) + batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + next_token_ids + ) # Move logprobs to cpu - if output.next_token_logprobs is not None: - output.next_token_logprobs = output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - output.input_token_logprobs = output.input_token_logprobs.tolist() - output.normalized_prompt_logprobs = ( - output.normalized_prompt_logprobs.tolist() + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs[ + torch.arange( + len(next_token_ids), device=next_token_ids.device + ), + next_token_ids, + ].tolist() + ) + logits_output.input_token_logprobs = ( + logits_output.input_token_logprobs.tolist() + ) + logits_output.normalized_prompt_logprobs = ( + 
logits_output.normalized_prompt_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() @@ -520,9 +584,15 @@ def forward_prefill_batch(self, batch: ScheduleBatch): req.output_ids.append(next_token_ids[i]) req.check_finished() + if req.regex_fsm is not None: + req.regex_fsm_state = req.regex_fsm.get_next_state( + req.regex_fsm_state, next_token_ids[i] + ) + if req.finished(): self.tree_cache.cache_finished_req(req) - else: + elif req not in decoding_reqs: + # To reduce overhead, only cache prefill reqs self.tree_cache.cache_unfinished_req(req) if req is self.current_inflight_req: @@ -530,12 +600,14 @@ def forward_prefill_batch(self, batch: ScheduleBatch): self.req_to_token_pool.free(req.req_pool_idx) if req.return_logprob: - self.add_logprob_return_values(i, req, pt, next_token_ids, output) + self.add_logprob_return_values( + i, req, pt, next_token_ids, logits_output + ) pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - embeddings = output.embeddings.tolist() + logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND) + embeddings = logits_output.embeddings.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): @@ -563,7 +635,7 @@ def add_logprob_return_values( req: Req, pt: int, next_token_ids: List[int], - output: LogitProcessorOutput, + output: LogitsProcessorOutput, ): if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] @@ -622,7 +694,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.new_token_ratio = new_token_ratio logger.info( - f"[gpu{self.gpu_id}]decode out of memory happened, " + "Decode out of memory happened. " f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) @@ -663,12 +735,17 @@ def forward_decode_batch(self, batch: ScheduleBatch): ) # Forward and sample the next tokens - output = self.model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids = batch.sample(output.next_token_logits) + sample_output, logits_output = self.model_runner.forward( + batch, ForwardMode.DECODE + ) + next_token_ids = batch.check_sample_results(sample_output) + batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + next_token_ids + ) # Move logprobs to cpu - if output.next_token_logprobs is not None: - next_token_logprobs = output.next_token_logprobs[ + if logits_output.next_token_logprobs is not None: + next_token_logprobs = logits_output.next_token_logprobs[ torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, ].tolist() @@ -681,6 +758,11 @@ def forward_decode_batch(self, batch: ScheduleBatch): req.output_ids.append(next_token_id) req.check_finished() + if req.regex_fsm is not None: + req.regex_fsm_state = req.regex_fsm.get_next_state( + req.regex_fsm_state, next_token_id + ) + if req.finished(): self.tree_cache.cache_finished_req(req) @@ -689,7 +771,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): (next_token_logprobs[i], next_token_id) ) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(output.output_top_logprobs[i]) + req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) self.handle_finished_requests(batch) @@ -804,12 +886,15 @@ def flush_cache(self): self.token_to_kv_pool.clear() torch.cuda.empty_cache() logger.info("Cache flushed successfully!") + if_success = True else: - warnings.warn( + logging.warning( f"Cache not flushed because there are pending requests. 
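Both the prefill and decode paths above pull the log-probability of each sampled token out of the full [batch, vocab] tensor with one advanced-indexing gather before moving it to the CPU. The indexing trick in isolation, on random data:

import torch

torch.manual_seed(0)
batch_size, vocab_size = 4, 10

# Pretend these came from the model head and the sampler.
next_token_logprobs_full = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
next_token_ids = torch.randint(0, vocab_size, (batch_size,))

# Pick logprob[i, next_token_ids[i]] for every row in one shot.
chosen = next_token_logprobs_full[
    torch.arange(batch_size, device=next_token_ids.device), next_token_ids
]
print(next_token_ids.tolist())
print(chosen.tolist())   # one scalar log-probability per request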
" f"#queue-req: {len(self.waiting_queue)}, " f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" ) + if_success = False + return if_success def abort_request(self, recv_req): # Delete requests in the waiting queue @@ -829,6 +914,15 @@ def abort_request(self, recv_req): req.finished_reason = FINISH_ABORT() break + def update_weights(self, recv_req): + success, message = self.model_runner.update_weights( + recv_req.model_path, recv_req.load_format + ) + if success: + flash_cache_success = self.flush_cache() + assert flash_cache_success, "Cache flush failed after updating weights" + return success, message + def run_tp_server( gpu_id: int, @@ -837,7 +931,9 @@ def run_tp_server( nccl_port: int, model_overide_args: dict, ): - """Run a tensor parallel server.""" + """Run a tensor parallel model server.""" + configure_logger(server_args, prefix=f" TP{tp_rank}") + try: model_server = ModelTpServer( gpu_id, @@ -893,6 +989,7 @@ def broadcast_recv_input( dist.broadcast(tensor_size, src=0, group=dist_group) dist.broadcast(tensor_data, src=0, group=dist_group) + return data else: tensor_size = torch.tensor([0], dtype=torch.long) dist.broadcast(tensor_size, src=0, group=dist_group) diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 35b9171e5b..e7e48ecee4 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -68,7 +68,7 @@ def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None): req.last_node = entry def insert(self): - raise NotImplementedError + raise NotImplementedError() def evict(self, num_tokens: int, evict_callback: Callable): pass diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 68cefbbf9f..fef74321ac 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -16,7 +16,8 @@ """Memory pool.""" import logging -from typing import List, Union +from abc import ABC, abstractmethod +from typing import List, Tuple, Union import torch @@ -52,14 +53,21 @@ def clear(self): self.free_slots = list(range(self.size)) -class BaseTokenToKVPool: +class BaseTokenToKVPool(ABC): """A memory pool that maps a token to its kv cache locations""" def __init__( self, size: int, + dtype: torch.dtype, ): self.size = size + self.dtype = dtype + if dtype == torch.float8_e5m2: + # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2 + self.store_dtype = torch.uint8 + else: + self.store_dtype = dtype # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") @@ -112,6 +120,28 @@ def clear(self): # We also add one slot. This slot is used for writing dummy output from padded tokens. 
self.mem_state[0] = False + @abstractmethod + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + @abstractmethod + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + @abstractmethod + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + @abstractmethod + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ) -> None: + raise NotImplementedError() + class MHATokenToKVPool(BaseTokenToKVPool): @@ -123,26 +153,52 @@ def __init__( head_dim: int, layer_num: int, ): - super().__init__(size) + super().__init__(size, dtype) # [size, head_num, head_dim] for each layer self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty( + (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" + ) for _ in range(layer_num) ] self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty( + (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" + ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.k_buffer[layer_id].view(self.dtype) return self.k_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.v_buffer[layer_id].view(self.dtype) return self.v_buffer[layer_id] def get_kv_buffer(self, layer_id: int): - return self.k_buffer[layer_id], self.v_buffer[layer_id] + return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) + + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + if cache_v.dtype != self.dtype: + cache_v = cache_v.to(self.dtype) + if self.store_dtype != self.dtype: + self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) + else: + self.k_buffer[layer_id][loc] = cache_k + self.v_buffer[layer_id][loc] = cache_v class MLATokenToKVPool(BaseTokenToKVPool): @@ -155,23 +211,41 @@ def __init__( qk_rope_head_dim: int, layer_num: int, ): - super().__init__(size) + super().__init__(size, dtype) self.kv_lora_rank = kv_lora_rank self.kv_buffer = [ torch.empty( (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=dtype, + dtype=self.store_dtype, device="cuda", ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id].view(self.dtype) return self.kv_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype) return self.kv_buffer[layer_id][..., : self.kv_lora_rank] def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) + + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + if self.store_dtype != self.dtype: + self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + else: + self.kv_buffer[layer_id][loc] = cache_k diff --git a/python/sglang/srt/mm_utils.py b/python/sglang/srt/mm_utils.py index e09c8215c6..7918f3f711 100644 --- a/python/sglang/srt/mm_utils.py 
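The memory-pool change above stores torch.float8_e5m2 KV entries in uint8 buffers, because indexed writes are not implemented for the fp8 dtype, and reinterprets the bytes back to fp8 on read. A minimal sketch of that round-trip, assuming a PyTorch build (>= 2.1) that exposes float8_e5m2:

import torch

size, heads, head_dim = 8, 2, 4
dtype, store_dtype = torch.float8_e5m2, torch.uint8

# Allocate the pool in the storage dtype (uint8) so indexed writes work.
k_buffer = torch.empty((size, heads, head_dim), dtype=store_dtype)

# Incoming cache values in the compute dtype (e.g. fp16/bf16).
loc = torch.tensor([1, 3, 5])
cache_k = torch.randn(len(loc), heads, head_dim).to(torch.float16)

# Write: cast to fp8, reinterpret the bytes as uint8, then index-assign.
k_buffer[loc] = cache_k.to(dtype).view(store_dtype)

# Read: index in uint8, then reinterpret the bytes as fp8 for downstream use.
k_fp8 = k_buffer[loc].view(dtype)
print(k_fp8.to(torch.float16))   # lossy fp8 round-trip of cache_k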
+++ b/python/sglang/srt/mm_utils.py @@ -13,10 +13,25 @@ limitations under the License. """ -# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py +# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py +""" +Utilities for multi-modal models. + +This python file mainly contains utilities that were used in the +image processing logic of llava-next including operations such as +anyres and anyres_max + +Currently supports the anyres and anyres_max operation for CLIP and +SigLip. For more information, you may refer to the paper or the blog + +LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/ +LLaVA-Onevision : https://arxiv.org/pdf/2408.03326 + +""" import ast import base64 import math +import re from io import BytesIO import numpy as np @@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions): min_wasted_resolution = float("inf") for width, height in possible_resolutions: + # Calculate the downscaled size to keep the aspect ratio scale = min(width / original_width, height / original_height) downscaled_width, downscaled_height = int(original_width * scale), int( original_height * scale ) + + # Calculate effective and wasted resolutions effective_resolution = min( downscaled_width * downscaled_height, original_width * original_height ) @@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Returns: tuple: The shape of the image patch grid in the format (width, height). """ + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: @@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints): Returns: np.array: An np array containing the processed image patches. 
""" + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: @@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints): best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) - patches = divide_to_patches(image_padded, processor.crop_size["height"]) - - image_original_resize = image.resize( - (processor.size["shortest_edge"], processor.size["shortest_edge"]) + # For Siglip processor, only have size but no crop size + crop_size = ( + processor.crop_size["height"] + if "crop_size" in processor.__dict__ + else processor.size["height"] ) + shortest_edge = ( + processor.size["shortest_edge"] + if "shortest_edge" in processor.size + else processor.size["height"] + ) + patches = divide_to_patches(image_padded, crop_size) + + image_original_resize = image.resize((shortest_edge, shortest_edge)) image_patches = [image_original_resize] + patches image_patches = [ - processor.preprocess(image_patch)["pixel_values"][0] + processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0] for image_patch in image_patches ] return np.stack(image_patches, axis=0) @@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg): ) image = image_processor.preprocess(image)["pixel_values"][0] new_images.append(image) - elif image_aspect_ratio == "anyres": + elif "anyres" in image_aspect_ratio: for image in images: image = process_anyres_image( image, image_processor, model_cfg.image_grid_pinpoints diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9bfd4a646c..40c87af88c 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -17,6 +17,7 @@ import bisect from contextlib import contextmanager +from typing import Callable, List import torch from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -25,16 +26,18 @@ from vllm.model_executor.custom_op import CustomOp from sglang.srt.layers.logits_processor import ( - LogitProcessorOutput, LogitsMetadata, LogitsProcessor, + LogitsProcessorOutput, ) +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, InputMetadata, update_flashinfer_indices, ) +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import monkey_patch_vllm_all_gather @@ -51,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): @contextmanager def patch_model( - model: torch.nn.Module, 
use_compile: bool, tp_group: "GroupCoordinator" + model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" ): backup_ca_comm = None try: - if use_compile: + if enable_compile: _to_torch(model) monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm @@ -65,7 +68,7 @@ def patch_model( else: yield model.forward finally: - if use_compile: + if enable_compile: _to_torch(model, reverse=True) monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -84,13 +87,20 @@ def set_torch_compile_config(): class CudaGraphRunner: - def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): + def __init__( + self, + model_runner: "ModelRunner", + max_batch_size_to_capture: int, + use_torch_compile: bool, + disable_padding: bool, + ): self.model_runner = model_runner self.graphs = {} self.input_buffers = {} self.output_buffers = {} self.flashinfer_handlers = {} self.graph_memory_pool = None + self.disable_padding = disable_padding # Common inputs self.max_bs = max_batch_size_to_capture @@ -98,8 +108,8 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): self.req_pool_indices = torch.zeros( (self.max_bs,), dtype=torch.int32, device="cuda" ) - self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda") - self.position_ids_offsets = torch.zeros( + self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") + self.position_ids_offsets = torch.ones( (self.max_bs,), dtype=torch.int32, device="cuda" ) self.out_cache_loc = torch.zeros( @@ -107,9 +117,6 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): ) # FlashInfer inputs - self.flashinfer_workspace_buffer = ( - self.model_runner.flashinfer_workspace_buffers[0] - ) self.flashinfer_kv_indptr = torch.zeros( (self.max_bs + 1,), dtype=torch.int32, device="cuda" ) @@ -121,16 +128,40 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): self.flashinfer_kv_last_page_len = torch.ones( (self.max_bs,), dtype=torch.int32, device="cuda" ) + if model_runner.sliding_window_size is None: + self.flashinfer_workspace_buffer = ( + self.model_runner.flashinfer_workspace_buffer + ) + else: + self.flashinfer_workspace_buffer = ( + self.model_runner.flashinfer_workspace_buffer + ) + + self.flashinfer_kv_indptr = [ + self.flashinfer_kv_indptr, + self.flashinfer_kv_indptr.clone(), + ] + self.flashinfer_kv_indices = [ + self.flashinfer_kv_indices, + self.flashinfer_kv_indices.clone(), + ] + + # Sampling inputs + vocab_size = model_runner.model_config.vocab_size + self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] if use_torch_compile: set_torch_compile_config() - def can_run(self, batch_size): - return batch_size < self.max_bs + def can_run(self, batch_size: int): + if self.disable_padding: + return batch_size in self.graphs + else: + return batch_size <= self.max_bs - def capture(self, batch_size_list): + def capture(self, batch_size_list: List[int]): self.batch_size_list = batch_size_list with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream @@ -151,7 +182,7 @@ def capture(self, batch_size_list): self.output_buffers[bs] = output_buffers self.flashinfer_handlers[bs] = flashinfer_handler - def capture_one_batch_size(self, bs, forward): + def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream @@ -171,15 +202,32 
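CudaGraphRunner above captures one CUDA graph per batch size in batch_size_list; in the replay hunk that follows, the raw batch size is padded up to the nearest captured size with bisect unless padding is disabled. A CPU-only illustration of just that size-selection logic, with no actual CUDA graphs involved:

import bisect


class GraphSizeSelector:
    def __init__(self, batch_size_list, disable_padding=False):
        self.batch_size_list = sorted(batch_size_list)   # captured graph sizes
        self.disable_padding = disable_padding

    def can_run(self, batch_size: int) -> bool:
        if self.disable_padding:
            return batch_size in self.batch_size_list
        return batch_size <= self.batch_size_list[-1]

    def padded_size(self, raw_batch_size: int) -> int:
        # Smallest captured size that is >= the real batch size.
        index = bisect.bisect_left(self.batch_size_list, raw_batch_size)
        return self.batch_size_list[index]


selector = GraphSizeSelector([1, 2, 4, 8, 16, 24, 32])
for raw_bs in (1, 3, 9, 32):
    print(raw_bs, "->", selector.padded_size(raw_bs), selector.can_run(raw_bs))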
@@ def capture_one_batch_size(self, bs, forward): use_tensor_cores = True else: use_tensor_cores = False - flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=use_tensor_cores, - paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], - paged_kv_indices_buffer=self.flashinfer_kv_indices, - paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], - ) + if self.model_runner.sliding_window_size is None: + flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=use_tensor_cores, + paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], + paged_kv_indices_buffer=self.flashinfer_kv_indices, + paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], + ) + else: + flashinfer_decode_wrapper = [] + for i in range(2): + flashinfer_decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=use_tensor_cores, + paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1], + paged_kv_indices_buffer=self.flashinfer_kv_indices[i], + paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[ + :bs + ], + ) + ) update_flashinfer_indices( ForwardMode.DECODE, self.model_runner, @@ -193,6 +241,7 @@ def capture_one_batch_size(self, bs, forward): def run_once(): input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, + sampling_info=self.sampling_info[:bs], batch_size=bs, req_pool_indices=req_pool_indices, seq_lens=seq_lens, @@ -201,19 +250,30 @@ def run_once(): out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=0, - positions=(seq_lens - 1).to(torch.int64), + positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), flashinfer_decode_wrapper=flashinfer_decode_wrapper, ) return forward(input_ids, input_metadata.positions, input_metadata) for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + run_once() + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): out = run_once() + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + self.graph_memory_pool = graph.pool() return graph, None, out, flashinfer_decode_wrapper @@ -225,8 +285,8 @@ def replay(self, batch: ScheduleBatch): index = bisect.bisect_left(self.batch_size_list, raw_bs) bs = self.batch_size_list[index] if bs != raw_bs: - self.seq_lens.fill_(1) - self.position_ids_offsets.zero_() + self.seq_lens.zero_() + self.position_ids_offsets.fill_(1) self.out_cache_loc.zero_() # Common inputs @@ -246,25 +306,35 @@ def replay(self, batch: ScheduleBatch): self.flashinfer_handlers[bs], ) + # Sampling inputs + self.sampling_info.inplace_assign(raw_bs, batch.sampling_info) + # Replay + torch.cuda.synchronize() self.graphs[bs].replay() - output = self.output_buffers[bs] + torch.cuda.synchronize() + sample_output, logits_output = self.output_buffers[bs] # Unpad if bs != raw_bs: - output = LogitProcessorOutput( - next_token_logits=output.next_token_logits[:raw_bs], + logits_output = LogitsProcessorOutput( + next_token_logits=logits_output.next_token_logits[:raw_bs], next_token_logprobs=None, normalized_prompt_logprobs=None, input_token_logprobs=None, input_top_logprobs=None, output_top_logprobs=None, ) + sample_output = 
SampleOutput( + sample_output.success[:raw_bs], + sample_output.probs[:raw_bs], + sample_output.batch_next_token_ids[:raw_bs], + ) # Extract logprobs if batch.return_logprob: - output.next_token_logprobs = torch.nn.functional.log_softmax( - output.next_token_logits, dim=-1 + logits_output.next_token_logprobs = torch.nn.functional.log_softmax( + logits_output.next_token_logits, dim=-1 ) return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) if return_top_logprob: @@ -272,8 +342,8 @@ def replay(self, batch: ScheduleBatch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=batch.top_logprobs_nums, ) - output.output_top_logprobs = LogitsProcessor.get_top_logprobs( - output.next_token_logprobs, logits_metadata + logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + logits_output.next_token_logprobs, logits_metadata )[1] - return output + return sample_output, logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index eb7aaaf2c1..3d40c9d755 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +28,7 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo class ForwardMode(IntEnum): @@ -42,6 +45,7 @@ class InputMetadata: """Store all inforamtion of a forward pass.""" forward_mode: ForwardMode + sampling_info: SamplingBatchInfo batch_size: int req_pool_indices: torch.Tensor seq_lens: torch.Tensor @@ -58,17 +62,20 @@ class InputMetadata: # For extend extend_seq_lens: torch.Tensor = None + extend_prefix_lens: torch.Tensor = None extend_start_loc: torch.Tensor = None extend_no_prefix: bool = None - # Output options + # For logprob return_logprob: bool = False top_logprobs_nums: List[int] = None + extend_seq_lens_cpu: List[int] = None + logprob_start_lens_cpu: List[int] = None # For multimodal pixel_values: List[torch.Tensor] = None - image_sizes: List[List[int]] = None - image_offsets: List[int] = None + image_sizes: List[List[List[int]]] = None + image_offsets: List[List[int]] = None # Trition attention backend triton_max_seq_len: int = 0 @@ -85,15 +92,8 @@ class InputMetadata: def init_multimuldal_info(self, batch: ScheduleBatch): reqs = batch.reqs self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [ - ( - (r.image_offset - len(r.prefix_indices)) - if r.image_offset is not None - else 0 - ) - for r in reqs - ] + self.image_sizes = [r.image_sizes for r in reqs] + self.image_offsets = [r.image_offsets for r in reqs] def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets @@ -109,8 +109,8 @@ def compute_positions(self, batch: ScheduleBatch): self.positions = torch.tensor( np.concatenate( [ - np.arange(len(req.prefix_indices), len(req.fill_ids)) - for req in batch.reqs + np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids)) + for i, req in enumerate(batch.reqs) ], axis=0, ), @@ -123,7 +123,7 @@ def compute_positions(self, batch: ScheduleBatch): np.concatenate( [ np.arange( - len(req.prefix_indices) + position_ids_offsets_cpu[i], + batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i], len(req.fill_ids) + position_ids_offsets_cpu[i], ) for 
i, req in enumerate(batch.reqs) @@ -139,14 +139,30 @@ def compute_positions(self, batch: ScheduleBatch): def compute_extend_infos(self, batch: ScheduleBatch): if self.forward_mode == ForwardMode.DECODE: self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None + self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None else: extend_lens_cpu = [ - len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs + len(r.fill_ids) - batch.prefix_lens_cpu[i] + for i, r in enumerate(batch.reqs) ] self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") + self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) - self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs) + self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) + + self.extend_seq_lens_cpu = extend_lens_cpu + self.logprob_start_lens_cpu = [ + ( + min( + req.logprob_start_len - batch.prefix_lens_cpu[i], + extend_lens_cpu[i] - 1, + ) + if req.logprob_start_len >= batch.prefix_lens_cpu[i] + else extend_lens_cpu[i] - 1 # Fake extend, actually decode + ) + for i, req in enumerate(batch.reqs) + ] @classmethod def from_schedule_batch( @@ -157,6 +173,7 @@ def from_schedule_batch( ): ret = cls( forward_mode=forward_mode, + sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, @@ -167,6 +184,8 @@ def from_schedule_batch( top_logprobs_nums=batch.top_logprobs_nums, ) + ret.sampling_info.prepare_penalties() + ret.compute_positions(batch) ret.compute_extend_infos(batch) @@ -180,44 +199,47 @@ def from_schedule_batch( if forward_mode != ForwardMode.DECODE: ret.init_multimuldal_info(batch) - prefix_lens = None - if forward_mode != ForwardMode.DECODE: - prefix_lens = torch.tensor( - [len(r.prefix_indices) for r in batch.reqs], device="cuda" - ) - if model_runner.server_args.disable_flashinfer: - ret.init_triton_args(batch, prefix_lens) + ret.init_triton_args(batch) flashinfer_use_ragged = False if not model_runner.server_args.disable_flashinfer: if ( forward_mode != ForwardMode.DECODE and int(torch.sum(ret.seq_lens)) > 4096 + and model_runner.sliding_window_size is None ): flashinfer_use_ragged = True ret.init_flashinfer_handlers( - model_runner, prefix_lens, flashinfer_use_ragged + model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged ) return ret - def init_triton_args(self, batch: ScheduleBatch, prefix_lens): + def init_triton_args(self, batch: ScheduleBatch): """Init auxiliary variables for triton attention backend.""" self.triton_max_seq_len = int(torch.max(self.seq_lens)) - self.triton_prefix_lens = prefix_lens self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) if self.forward_mode == ForwardMode.DECODE: self.triton_max_extend_len = None else: - extend_seq_lens = self.seq_lens - prefix_lens + self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") + extend_seq_lens = self.seq_lens - self.triton_prefix_lens self.triton_max_extend_len = int(torch.max(extend_seq_lens)) def init_flashinfer_handlers( - self, model_runner, prefix_lens, flashinfer_use_ragged + self, + model_runner, + prefix_lens_cpu, + flashinfer_use_ragged, ): + if self.forward_mode == ForwardMode.DECODE: + prefix_lens = None + else: + prefix_lens = self.extend_prefix_lens + 
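compute_extend_infos above derives each request's extend length (fill length minus cached prefix) and turns the lengths into flat batch offsets with an exclusive cumulative sum. The offset math on its own, with made-up lengths:

import torch

# Per request: total tokens to fill and tokens already cached in the prefix.
fill_lens = torch.tensor([7, 5, 9])
prefix_lens = torch.tensor([3, 0, 4])

extend_seq_lens = fill_lens - prefix_lens            # new tokens per request
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)   # exclusive cumsum

print(extend_seq_lens.tolist())    # [4, 5, 5]
print(extend_start_loc.tolist())   # [0, 4, 9] -> start offset of each request's slice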
update_flashinfer_indices( self.forward_mode, model_runner, @@ -255,65 +277,139 @@ def update_flashinfer_indices( head_dim = model_runner.model_config.head_dim batch_size = len(req_pool_indices) - if flashinfer_use_ragged: - paged_kernel_lens = prefix_lens - else: - paged_kernel_lens = seq_lens - - kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - req_pool_indices_cpu = req_pool_indices.cpu().numpy() - paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() - kv_indices = torch.cat( - [ - model_runner.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] - ] - for i in range(batch_size) - ], - dim=0, - ).contiguous() - kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") - - if forward_mode == ForwardMode.DECODE: - # CUDA graph uses different flashinfer_decode_wrapper - if flashinfer_decode_wrapper is None: - flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper - - flashinfer_decode_wrapper.end_forward() - flashinfer_decode_wrapper.begin_forward( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) - else: - # extend part - qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) - + if model_runner.sliding_window_size is None: if flashinfer_use_ragged: - model_runner.flashinfer_prefill_wrapper_ragged.end_forward() - model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( - qo_indptr, - qo_indptr, + paged_kernel_lens = prefix_lens + else: + paged_kernel_lens = seq_lens + + kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() + kv_indices = torch.cat( + [ + model_runner.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] + ] + for i in range(batch_size) + ], + dim=0, + ).contiguous() + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + + if forward_mode == ForwardMode.DECODE: + # CUDA graph uses different flashinfer_decode_wrapper + if flashinfer_decode_wrapper is None: + flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper + + flashinfer_decode_wrapper.end_forward() + flashinfer_decode_wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, num_qo_heads, num_kv_heads, head_dim, + 1, + data_type=model_runner.kv_cache_dtype, + q_data_type=model_runner.dtype, ) + else: + # extend part + qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) + + if flashinfer_use_ragged: + model_runner.flashinfer_prefill_wrapper_ragged.end_forward() + model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + qo_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) - # cached part - model_runner.flashinfer_prefill_wrapper_paged.end_forward() - model_runner.flashinfer_prefill_wrapper_paged.begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) + # cached part + model_runner.flashinfer_prefill_wrapper_paged.end_forward() + model_runner.flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + 
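update_flashinfer_indices above flattens the used prefix of each request's req_to_token row into a CSR-style pair (kv_indptr, kv_indices) that the FlashInfer wrappers consume. A stand-alone version of that layout using a toy req_to_token table:

import torch

max_context = 8
# Toy req_to_token pool: row r holds the KV-cache slot of each token of request r.
req_to_token = torch.arange(3 * max_context).reshape(3, max_context)

req_pool_indices = torch.tensor([2, 0])      # which pool rows this batch uses
paged_kernel_lens = torch.tensor([5, 3])     # tokens per request in this batch

# CSR-style index pointer: kv_indptr[i]..kv_indptr[i+1] is request i's slice.
kv_indptr = torch.zeros(len(req_pool_indices) + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Flatten the used prefix of each selected row into one contiguous index array.
kv_indices = torch.cat(
    [
        req_to_token[req_pool_indices[i], : paged_kernel_lens[i]]
        for i in range(len(req_pool_indices))
    ]
)
print(kv_indptr.tolist())    # [0, 5, 8]
print(kv_indices.tolist())   # [16, 17, 18, 19, 20, 0, 1, 2]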
num_kv_heads, + head_dim, + 1, + ) + else: + # window attention use paged only + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + for wrapper_id in range(2): + if wrapper_id == 0: + if forward_mode == ForwardMode.DECODE: + paged_kernel_lens = torch.minimum( + seq_lens, torch.tensor(model_runner.sliding_window_size + 1) + ) + else: + paged_kernel_lens = torch.minimum( + seq_lens, + torch.tensor(model_runner.sliding_window_size) + + seq_lens + - prefix_lens, + ) + else: + paged_kernel_lens = seq_lens + + kv_start_idx = seq_lens - paged_kernel_lens + + kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() + kv_indices = torch.cat( + [ + model_runner.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[i], + kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i], + ] + for i in range(batch_size) + ], + dim=0, + ).contiguous() + + if forward_mode == ForwardMode.DECODE: + # CUDA graph uses different flashinfer_decode_wrapper + if flashinfer_decode_wrapper is None: + flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper + + flashinfer_decode_wrapper[wrapper_id].end_forward() + flashinfer_decode_wrapper[wrapper_id].begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + data_type=model_runner.kv_cache_dtype, + q_data_type=model_runner.dtype, + ) + else: + # extend part + qo_indptr = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) + + model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward() + model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 574ad36580..e6f5e74311 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -15,13 +15,13 @@ """ModelRunner runs the forward passes of the models.""" +import gc import importlib import importlib.resources import logging import pkgutil -import warnings from functools import lru_cache -from typing import Optional, Type +from typing import Optional, Tuple, Type import torch import torch.nn as nn @@ -37,23 +37,28 @@ get_tp_group, init_distributed_environment, initialize_model_parallel, + set_custom_all_reduce, ) +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_config import AttentionArch +from sglang.srt.model_config import AttentionArch, ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, is_generation_model, - 
is_llama3_405b_fp8, + is_llama3_405b_fp8_head_16, is_multimodal_model, monkey_patch_vllm_dummy_weight_loader, monkey_patch_vllm_p2p_access_check, @@ -66,7 +71,7 @@ class ModelRunner: def __init__( self, - model_config, + model_config: ModelConfig, mem_fraction_static: float, gpu_id: int, tp_rank: int, @@ -82,27 +87,49 @@ def __init__( self.tp_size = tp_size self.nccl_port = nccl_port self.server_args = server_args - self.is_multimodal_model = is_multimodal_model(self.model_config) + self.is_multimodal_model = is_multimodal_model( + self.model_config.hf_config.architectures + ) global_server_args_dict.update( { "disable_flashinfer": server_args.disable_flashinfer, "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, - "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32, + "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, } ) + if self.is_multimodal_model: + logger.info( + "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." + ) + server_args.chunked_prefill_size = None + server_args.mem_fraction_static *= 0.95 + + min_per_gpu_memory = self.init_torch_distributed() + self.load_model() + self.init_memory_pool( + min_per_gpu_memory, + server_args.max_num_reqs, + server_args.max_total_tokens, + ) + self.init_cublas() + self.init_flashinfer() + self.init_cuda_graphs() + + def init_torch_distributed(self): # Init torch distributed torch.cuda.set_device(self.gpu_id) - logger.info(f"[gpu={self.gpu_id}] Init nccl begin.") + logger.info("Init nccl begin.") - if not server_args.enable_p2p_check: + if not self.server_args.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) - if server_args.nccl_init_addr: - nccl_init_method = f"tcp://{server_args.nccl_init_addr}" + if self.server_args.nccl_init_addr: + nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}" else: nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}" + set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) init_distributed_environment( backend="nccl", world_size=self.tp_size, @@ -111,43 +138,45 @@ def __init__( distributed_init_method=nccl_init_method, ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) - self.tp_group = get_tp_group() - total_gpu_memory = get_available_gpu_memory( + min_per_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 ) + self.tp_group = get_tp_group() + # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph, + # so we disable padding in cuda graph. + if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)): + self.server_args.disable_cuda_graph_padding = True + logger.info( + "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism." + ) + + # Check memory for tensor parallelism if self.tp_size > 1: - total_local_gpu_memory = get_available_gpu_memory(self.gpu_id) - if total_local_gpu_memory < total_gpu_memory * 0.9: + local_gpu_memory = get_available_gpu_memory(self.gpu_id) + if min_per_gpu_memory < local_gpu_memory * 0.9: raise ValueError( "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." 
) - # Load the model and create memory pool - self.load_model() - self.init_memory_pool( - total_gpu_memory, - server_args.max_num_reqs, - server_args.max_total_tokens, - ) - self.init_cublas() - self.init_flashinfer() - - if self.is_generation: - # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models - # Capture cuda graphs - self.init_cuda_graphs() + return min_per_gpu_memory def load_model(self): logger.info( - f"[gpu={self.gpu_id}] Load weight begin. " - f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) + if torch.cuda.get_device_capability()[0] < 8: + logger.info( + "Compute capability below sm80. Use float16 due to lack of bfloat16 support." + ) + self.server_args.dtype = "float16" + if torch.cuda.get_device_capability()[1] < 5: + raise RuntimeError("SGLang only supports sm75 and above.") monkey_patch_vllm_dummy_weight_loader() - device_config = DeviceConfig() - load_config = LoadConfig(load_format=self.server_args.load_format) - vllm_model_config = VllmModelConfig( + self.device_config = DeviceConfig() + self.load_config = LoadConfig(load_format=self.server_args.load_format) + self.vllm_model_config = VllmModelConfig( model=self.server_args.model_path, quantization=self.server_args.quantization, tokenizer=None, @@ -158,47 +187,132 @@ def load_model(self): skip_tokenizer_init=True, ) - if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8: - # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints + # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints + # Drop this after Sept, 2024. + if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8: self.model_config.hf_config.num_key_value_heads = 8 - vllm_model_config.hf_config.num_key_value_heads = 8 + self.vllm_model_config.hf_config.num_key_value_heads = 8 monkey_patch_vllm_qvk_linear_loader() - self.dtype = vllm_model_config.dtype + self.dtype = self.vllm_model_config.dtype if self.model_config.model_overide_args is not None: - vllm_model_config.hf_config.update(self.model_config.model_overide_args) - - if ( - self.server_args.efficient_weight_load - and "llama" in self.server_args.model_path.lower() - and self.server_args.quantization == "fp8" - ): - from sglang.srt.model_loader.model_loader import get_model - else: - from vllm.model_executor.model_loader import get_model + self.vllm_model_config.hf_config.update( + self.model_config.model_overide_args + ) self.model = get_model( - model_config=vllm_model_config, - device_config=device_config, - load_config=load_config, - lora_config=None, - multimodal_config=None, + model_config=self.vllm_model_config, + load_config=self.load_config, + device_config=self.device_config, parallel_config=None, scheduler_config=None, + lora_config=None, cache_config=None, ) + self.sliding_window_size = ( + self.model.get_attention_sliding_window_size() + if hasattr(self.model, "get_attention_sliding_window_size") + else None + ) self.is_generation = is_generation_model( - self.model_config.hf_config.architectures + self.model_config.hf_config.architectures, self.server_args.is_embedding ) logger.info( - f"[gpu={self.gpu_id}] Load weight end. " + f"Load weight end. 
" f"type={type(self.model).__name__}, " f"dtype={self.dtype}, " f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) - def profile_max_num_token(self, total_gpu_memory): + def update_weights(self, model_path: str, load_format: str): + """Update weights in-place.""" + from vllm.model_executor.model_loader.loader import ( + DefaultModelLoader, + device_loading_context, + get_model_loader, + ) + from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + logger.info( + f"Update weights begin. " + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) + + target_device = torch.device(self.device_config.device) + + try: + # TODO: Use a better method to check this + vllm_model_config = VllmModelConfig( + model=model_path, + quantization=self.server_args.quantization, + tokenizer=None, + tokenizer_mode=None, + trust_remote_code=self.server_args.trust_remote_code, + dtype=self.server_args.dtype, + seed=42, + skip_tokenizer_init=True, + ) + except Exception as e: + logger.error(f"Failed to load model config: {e}") + return False, "Failed to update model weights" + + load_config = LoadConfig(load_format=load_format) + + # Only support vllm DefaultModelLoader for now + loader = get_model_loader(load_config) + if not isinstance(loader, DefaultModelLoader): + logger.error("Failed to get weights iterator: Unsupported loader") + return False, "Failed to update model weights" + + def get_weight_iter(config): + iter = loader._get_weights_iterator( + config.model, + config.revision, + fall_back_to_pt=getattr( + self.model, "fall_back_to_pt_during_load", True + ), + ) + return iter + + def model_load_weights(model, iter): + model.load_weights(iter) + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model + + with set_default_torch_dtype(vllm_model_config.dtype): + try: + iter = get_weight_iter(vllm_model_config) + except Exception as e: + message = f"Failed to get weights iterator: {e}" + logger.error(message) + return False, message + try: + model = model_load_weights(self.model, iter) + except Exception as e: + message = f"Failed to update weights: {e}. 
\n Rolling back to original weights" + logger.error(message) + del iter + gc.collect() + iter = get_weight_iter(self.vllm_model_config) + self.model = model_load_weights(self.model, iter) + return False, message + + self.model = model + self.server_args.model_path = model_path + self.server_args.load_format = load_format + self.vllm_model_config = vllm_model_config + self.load_config = load_config + self.model_config.path = model_path + + logger.info("Update weights end.") + return True, "Succeeded to update model weights" + + def profile_max_num_token(self, total_gpu_memory: int): available_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 ) @@ -209,7 +323,7 @@ def profile_max_num_token(self, total_gpu_memory): cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * self.model_config.num_hidden_layers - * torch._utils._element_size(self.dtype) + * torch._utils._element_size(self.kv_cache_dtype) ) else: cell_size = ( @@ -217,7 +331,7 @@ def profile_max_num_token(self, total_gpu_memory): * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 - * torch._utils._element_size(self.dtype) + * torch._utils._element_size(self.kv_cache_dtype) ) rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static @@ -226,12 +340,30 @@ def profile_max_num_token(self, total_gpu_memory): return max_num_token def init_memory_pool( - self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None + self, + total_gpu_memory: int, + max_num_reqs: int = None, + max_total_tokens: int = None, ): + if self.server_args.kv_cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + elif self.server_args.kv_cache_dtype == "fp8_e5m2": + if self.server_args.disable_flashinfer or self.server_args.enable_mla: + logger.warning( + "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype" + ) + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = torch.float8_e5m2 + else: + raise ValueError( + f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." + ) + self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: - warnings.warn( + logging.warning( f"max_total_tokens={max_total_tokens} is larger than the profiled value " f"{self.max_total_num_tokens}. " f"Use the profiled value instead." @@ -264,7 +396,7 @@ def init_memory_pool( ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, - dtype=self.dtype, + dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, @@ -275,13 +407,13 @@ def init_memory_pool( else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, - dtype=self.dtype, + dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(self.tp_size), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, ) logger.info( - f"[gpu={self.gpu_id}] Memory pool end. " + f"Memory pool end. 
" f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) @@ -295,7 +427,11 @@ def init_cublas(self): return c def init_flashinfer(self): + """Init flashinfer attention kernel wrappers.""" if self.server_args.disable_flashinfer: + assert ( + self.sliding_window_size is None + ), "turn on flashinfer to support window attention" self.flashinfer_prefill_wrapper_ragged = None self.flashinfer_prefill_wrapper_paged = None self.flashinfer_decode_wrapper = None @@ -309,36 +445,72 @@ def init_flashinfer(self): else: use_tensor_cores = False - self.flashinfer_workspace_buffers = torch.empty( - 2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda" - ) - self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_workspace_buffers[0], "NHD" - ) - self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffers[1], "NHD" - ) - self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffers[0], - "NHD", - use_tensor_cores=use_tensor_cores, - ) + if self.sliding_window_size is None: + self.flashinfer_workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device="cuda", + ) + self.flashinfer_prefill_wrapper_ragged = ( + BatchPrefillWithRaggedKVCacheWrapper( + self.flashinfer_workspace_buffer, "NHD" + ) + ) + self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, "NHD" + ) + self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, + "NHD", + use_tensor_cores=use_tensor_cores, + ) + else: + self.flashinfer_workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device="cuda", + ) + self.flashinfer_prefill_wrapper_ragged = None + self.flashinfer_prefill_wrapper_paged = [] + self.flashinfer_decode_wrapper = [] + for i in range(2): + self.flashinfer_prefill_wrapper_paged.append( + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, "NHD" + ) + ) + self.flashinfer_decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, + "NHD", + use_tensor_cores=use_tensor_cores, + ) + ) def init_cuda_graphs(self): + """Capture cuda graphs.""" + if not self.is_generation: + # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models + return + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: self.cuda_graph_runner = None return - logger.info( - f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes." - ) - batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)] + logger.info("Capture cuda graph begin. This can take up to several minutes.") + + if self.server_args.disable_cuda_graph_padding: + batch_size_list = list(range(1, 32)) + [64, 128] + else: + batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)] + self.cuda_graph_runner = CudaGraphRunner( self, max_batch_size_to_capture=max(batch_size_list), use_torch_compile=self.server_args.enable_torch_compile, + disable_padding=self.server_args.disable_cuda_graph_padding, ) try: self.cuda_graph_runner.capture(batch_size_list) @@ -346,19 +518,25 @@ def init_cuda_graphs(self): raise Exception( f"Capture cuda graph failed: {e}\n" "Possible solutions:\n" - "1. 
disable torch compile by not using --enable-torch-compile\n" - "2. disable cuda graph by --disable-cuda-graph\n" - "3. set --mem-fraction-static to a smaller value\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value\n" + "3. disable torch compile by not using --enable-torch-compile\n" "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" ) @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): + if ( + self.cuda_graph_runner + and self.cuda_graph_runner.can_run(len(batch.reqs)) + and not batch.sampling_info.has_bias() + ): return self.cuda_graph_runner.replay(batch) input_metadata = InputMetadata.from_schedule_batch( - self, batch, ForwardMode.DECODE + self, + batch, + ForwardMode.DECODE, ) return self.model.forward( @@ -368,16 +546,29 @@ def forward_decode(self, batch: ScheduleBatch): @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): input_metadata = InputMetadata.from_schedule_batch( - self, batch, forward_mode=ForwardMode.EXTEND - ) - return self.model.forward( - batch.input_ids, input_metadata.positions, input_metadata + self, + batch, + forward_mode=ForwardMode.EXTEND, ) + if self.is_generation: + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata + ) + else: + # Only embedding models have get_embedding parameter + return self.model.forward( + batch.input_ids, + input_metadata.positions, + input_metadata, + get_embedding=True, + ) @torch.inference_mode() def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata = InputMetadata.from_schedule_batch( - self, batch, forward_mode=ForwardMode.EXTEND + self, + batch, + forward_mode=ForwardMode.EXTEND, ) return self.model.forward( batch.input_ids, @@ -388,7 +579,9 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata.image_offsets, ) - def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): + def forward( + self, batch: ScheduleBatch, forward_mode: ForwardMode + ) -> Tuple[SampleOutput, LogitsProcessorOutput]: if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: return self.forward_extend_multi_modal(batch) elif forward_mode == ForwardMode.DECODE: @@ -444,4 +637,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: # Monkey patch model loader -setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) +setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt) diff --git a/python/sglang/srt/model_loader/model_loader.py b/python/sglang/srt/model_loader/model_loader.py deleted file mode 100644 index 4b7e32b6e5..0000000000 --- a/python/sglang/srt/model_loader/model_loader.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py -# FIXME: in progress of refactoring the model loader - -import glob -import os -import re -from typing import Any, Dict, Generator, List, Optional, Tuple, Type - -import torch -from torch import nn -from tqdm import tqdm -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoadConfig, - LoadFormat, - LoRAConfig, - ModelConfig, - MultiModalConfig, - ParallelConfig, - SchedulerConfig, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.model_loader.utils import ( - get_model_architecture, - set_default_torch_dtype, -) -from vllm.platforms import current_platform - -from sglang.srt.model_loader.utils import ( - download_safetensors_index_file_from_hf, - download_weights_from_hf, - filter_duplicate_safetensors_files, - get_quant_config, - safetensors_weights_iterator, -) - - -def _get_quantization_config( - model_config: ModelConfig, load_config: LoadConfig -) -> Optional[QuantizationConfig]: - """Get the quantization config.""" - if model_config.quantization is not None: - quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}." - ) - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}" - ) - return quant_config - return None - - -def _get_model_initialization_kwargs( - model_class: Type[nn.Module], - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], -) -> Dict[str, Any]: - """Get extra kwargs for model initialization.""" - extra_kwargs: Dict[str, Any] = {} - - assert lora_config is None - assert multimodal_config is None - - return extra_kwargs - - -def _initialize_model( - model_config: ModelConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig, -) -> nn.Module: - """Initialize a model with the given configurations.""" - model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config(model_config, load_config) - - return model_class( - config=model_config.hf_config, - cache_config=cache_config, - quant_config=quant_config, - efficient_weight_load=True, - **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config), - ) - - -class ModelLoader: - """Model loader that can load different file types from disk.""" - - def __init__(self, load_config: LoadConfig): - self.load_config = load_config - - def _prepare_weights( - self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool - ) -> Tuple[str, List[str], bool]: - """Prepare weights for the model. 
- - If the model is not local, it will be downloaded.""" - - is_local = os.path.isdir(model_name_or_path) - load_format = self.load_config.load_format - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == LoadFormat.AUTO: - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == LoadFormat.SAFETENSORS: - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == LoadFormat.PT: - allow_patterns = ["*.pt"] - elif load_format == LoadFormat.NPCACHE: - allow_patterns = ["*.bin"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if not is_local: - hf_folder = download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - allow_patterns, - revision, - ) - else: - hf_folder = model_name_or_path - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if use_safetensors: - # For models like Mistral-7B-Instruct-v0.3 - # there are both sharded safetensors files and a consolidated - # safetensors file. Using both breaks. - # Here, we download the `model.safetensors.index.json` and filter - # any files not found in the index. - if not is_local: - download_safetensors_index_file_from_hf( - model_name_or_path, self.load_config.download_dir, revision - ) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder - ) - else: - hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`" - ) - - return hf_folder, hf_weights_files, use_safetensors - - def _get_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Get an iterator for the model weights based on the load format.""" - hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision, fall_back_to_pt - ) - if self.load_config.load_format == LoadFormat.NPCACHE: - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - weights_iterator = np_cache_weights_iterator( - model_name_or_path, - self.load_config.download_dir, - hf_folder, - hf_weights_files, - ) - elif use_safetensors: - weights_iterator = safetensors_weights_iterator(hf_weights_files) - else: - weights_iterator = pt_weights_iterator(hf_weights_files) - - return weights_iterator - - def load_model( - self, - *, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model( - model_config, - self.load_config, - lora_config, - multimodal_config, - cache_config, - ) - weights = self._get_weights_iterator( - model_config.model, - model_config.revision, - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), - ) - - modules = {} - for name, module in model.named_modules(): - modules[name] = module - - def apply_quant_method(module): - quant_method = getattr(module, 
"quant_method", None) - if quant_method is not None: - # print("before apply quant", module.weight, module.weight.dtype) - quant_method.process_weights_after_loading(module) - # print("after apply quant", module.weight, module.weight.dtype) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - - if torch.cuda.current_device() == 0: - weights = tqdm( - weights, total=model.get_num_params() * 1.5, desc="load model" - ) - - num_shard = {} - num_loaded = {} - for name, loaded_weight in weights: - model.load_weights(None, name, loaded_weight) - module_name, shard_num = model.get_module_name(name) - num_shard[module_name] = shard_num - if module_name not in num_loaded: - num_loaded[module_name] = 1 - else: - num_loaded[module_name] += 1 - if num_loaded[module_name] == num_shard[module_name]: - apply_quant_method(modules[module_name]) - - return model.eval() - - -def get_model( - *, - model_config: ModelConfig, - load_config: LoadConfig, - device_config: DeviceConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig, -) -> nn.Module: - loader = ModelLoader(load_config) - return loader.load_model( - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config, - ) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py deleted file mode 100644 index 9d6520e2ae..0000000000 --- a/python/sglang/srt/model_loader/utils.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -# temporarily adapted from vLLM -# FIXME: in progress of refactoring the model loader -"""Utilities for selecting and loading models.""" -import contextlib -import fnmatch -import hashlib -import json -import logging -import os -import tempfile -from typing import Any, Generator, Iterable, List, Optional, Tuple, Type - -import filelock -import huggingface_hub.constants -import torch -from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download -from safetensors.torch import load_file, safe_open, save_file -from torch import nn -from tqdm.auto import tqdm -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig - -from sglang.srt.layers.quantization import get_quantization_config - -logger = logging.getLogger(__name__) -temp_dir = tempfile.gettempdir() - - -@contextlib.contextmanager -def set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - -def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. - if ( - model_config.quantization is not None - and model_config.quantization != "fp8" - and "MixtralForCausalLM" in architectures - ): - architectures = ["QuantMixtralForCausalLM"] - - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}" - ) - - -class DisabledTqdm(tqdm): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, disable=True) - - -def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): - lock_dir = cache_dir or temp_dir - os.makedirs(os.path.dirname(lock_dir), exist_ok=True) - model_name = model_name_or_path.replace("/", "-") - hash_name = hashlib.sha256(model_name.encode()).hexdigest() - # add hash to avoid conflict with old users' lock files - lock_file_name = hash_name + model_name + ".lock" - # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) - return lock - - -def download_weights_from_hf( - model_name_or_path: str, - cache_dir: Optional[str], - allow_patterns: List[str], - revision: Optional[str] = None, -) -> str: - """Download model weights from Hugging Face Hub. - - Args: - model_name_or_path (str): The model name or path. - cache_dir (Optional[str]): The cache directory to store the model - weights. If None, will use HF defaults. - allow_patterns (List[str]): The allowed patterns for the - weight files. Files matched by any of the patterns will be - downloaded. - revision (Optional[str]): The revision of the model. - - Returns: - str: The path to the downloaded model weights. 
- """ - if not huggingface_hub.constants.HF_HUB_OFFLINE: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info("Using model weights format %s", allow_patterns) - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download( - model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ) - return hf_folder - - -def download_safetensors_index_file_from_hf( - model_name_or_path: str, - cache_dir: Optional[str], - revision: Optional[str] = None, -) -> None: - """Download hf safetensors index file from Hugging Face Hub. - - Args: - model_name_or_path (str): The model name or path. - cache_dir (Optional[str]): The cache directory to store the model - weights. If None, will use HF defaults. - revision (Optional[str]): The revision of the model. - """ - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - try: - # Download the safetensors index file. - hf_hub_download( - repo_id=model_name_or_path, - filename=SAFE_WEIGHTS_INDEX_NAME, - cache_dir=cache_dir, - revision=revision, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ) - # If file not found on remote or locally, we should not fail since - # only some models will have SAFE_WEIGHTS_INDEX_NAME. - except huggingface_hub.utils.EntryNotFoundError: - logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) - except huggingface_hub.utils.LocalEntryNotFoundError: - logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) - - -# For models like Mistral-7B-v0.3, there are both sharded -# safetensors files and a consolidated safetensors file. -# Passing both of these to the weight loader functionality breaks. -# So, we use the SAFE_WEIGHTS_INDEX_NAME to -# look up which safetensors files should be used. -def filter_duplicate_safetensors_files( - hf_weights_files: List[str], hf_folder: str -) -> List[str]: - # model.safetensors.index.json is a mapping from keys in the - # torch state_dict to safetensors file holding that weight. - index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) - if not os.path.isfile(index_file_name): - return hf_weights_files - - # Iterate through the weight_map (weight_name: safetensors files) - # to identify weights that we should use. - with open(index_file_name) as index_file: - weight_map = json.load(index_file)["weight_map"] - weight_files_in_index = set() - for weight_name in weight_map: - weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) - # Filter out any fields that are not found in the index file. 
- hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] - return hf_weights_files - - -def safetensors_weights_iterator( - hf_weights_files: List[str], -) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files.""" - for st_file in hf_weights_files: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - - -def get_quant_config( - model_config: ModelConfig, load_config: LoadConfig -) -> QuantizationConfig: - quant_cls = get_quantization_config(model_config.quantization) - # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) - if hf_quant_config is None: - # compressed-tensors uses a compressions_config - hf_quant_config = getattr(model_config.hf_config, "compression_config", None) - if hf_quant_config is not None: - return quant_cls.from_config(hf_quant_config) - # In case of bitsandbytes/QLoRA, get quant config from the adapter model. - if model_config.quantization == "bitsandbytes": - if ( - not load_config.model_loader_extra_config - or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config - ): - return quant_cls.from_config({"adapter_name_or_path": ""}) - model_name_or_path = load_config.model_loader_extra_config[ - "qlora_adapter_name_or_path" - ] - - else: - model_name_or_path = model_config.model - is_local = os.path.isdir(model_name_or_path) - if not is_local: - # Download the config files. - with get_lock(model_name_or_path, load_config.download_dir): - hf_folder = snapshot_download( - model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=load_config.download_dir, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - tqdm_class=DisabledTqdm, - ) - else: - hf_folder = model_name_or_path - - possible_config_filenames = quant_cls.get_config_filenames() - - # If the quantization config is not found, use the default config. 
- if not possible_config_filenames: - return quant_cls() - - config_files = glob.glob(os.path.join(hf_folder, "*.json")) - - quant_config_files = [ - f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) - ] - if len(quant_config_files) == 0: - raise ValueError(f"Cannot find the config file for {model_config.quantization}") - if len(quant_config_files) > 1: - raise ValueError( - f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}" - ) - - quant_config_file = quant_config_files[0] - with open(quant_config_file, "r") as f: - config = json.load(f) - - if model_config.quantization == "bitsandbytes": - config["adapter_name_or_path"] = model_name_or_path - - return quant_cls.from_config(config) diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index d2ad02fbf4..9eb04dc263 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -17,15 +17,13 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from torch.nn import LayerNorm from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -33,18 +31,18 @@ ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None @@ -383,17 +381,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 1259285c46..c360106f97 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -50,7 +50,6 @@ get_tensor_model_parallel_rank, 
get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -62,8 +61,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.utils import set_weight_attrs +from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -326,6 +327,7 @@ def __init__( self.config = config self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() self.model = CohereModel(config, quant_config) @torch.no_grad() @@ -340,9 +342,11 @@ def forward( positions, input_metadata, ) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 39ac4aefa7..b3a76b56ae 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -382,6 +383,7 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -391,9 +393,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 98dcfd28df..b939602c1b 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -27,9 +27,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -44,8 +42,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -385,6 +386,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) 
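# --- Illustrative sketch (editorial addition, not part of the patch) ------------
# chatglm, commandr, dbrx, deepseek, deepseek_v2 and exaone all receive the same
# edit in this diff: construct a Sampler in __init__ and return a
# (sample_output, logits_output) tuple from forward() instead of bare logits.
# The toy classes below only sketch that new calling convention; they are
# hypothetical and are not the sglang Sampler / LogitsProcessor implementations.
# --------------------------------------------------------------------------------
from typing import Tuple

import torch


class ToySampler:
    """Hypothetical stand-in for the sampler wired into each model (greedy only)."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.argmax(logits, dim=-1)


class ToyCausalLM(torch.nn.Module):
    def __init__(self, hidden_size: int = 8, vocab_size: int = 16) -> None:
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.sampler = ToySampler()

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits_output = self.lm_head(hidden_states)  # next-token logits per request
        sample_output = self.sampler(logits_output)  # sampling now happens inside the model
        return sample_output, logits_output          # the new two-value return


# Caller side (roughly what a runner now has to unpack from model.forward):
next_token_ids, logits = ToyCausalLM()(torch.randn(2, 8))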
self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -394,9 +396,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 739562730b..67d99d5124 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -26,9 +26,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, @@ -43,8 +41,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -416,12 +417,8 @@ def __init__( v_head_dim=self.kv_lora_rank, ) - kv_b_proj = self.kv_b_proj - w_kc, w_vc = kv_b_proj.weight.unflatten( - 0, (-1, qk_nope_head_dim + v_head_dim) - ).split([qk_nope_head_dim, v_head_dim], dim=1) - self.w_kc = w_kc - self.w_vc = w_vc + self.w_kc = None + self.w_vc = None def forward( self, @@ -445,11 +442,12 @@ def forward( q_nope_out = q_input[..., : self.kv_lora_rank] torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1)) - k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1) - k_pe = k_input[..., self.kv_lora_rank :] - v_input = k_input[..., : self.kv_lora_rank] - v_input = self.kv_a_layernorm(v_input.contiguous()) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + v_input = latent_cache[..., : self.kv_lora_rank] + v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) + k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) q_input[..., self.kv_lora_rank :] = q_pe @@ -462,7 +460,7 @@ def forward( ) torch.bmm( attn_output.transpose(0, 1), - self.w_vc.transpose(1, 2).contiguous(), + self.w_vc, out=attn_bmm_output.transpose(0, 1), ) @@ -631,6 +629,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() def forward( self, @@ -639,9 +638,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, 
logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -710,5 +711,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) + if global_server_args_dict["enable_mla"]: + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.contiguous() + self_attn.w_vc = w_vc.transpose(1, 2).contiguous() + del self_attn.kv_b_proj + EntryClass = DeepseekV2ForCausalLM diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py new file mode 100644 index 0000000000..4dcafed7ce --- /dev/null +++ b/python/sglang/srt/models/exaone.py @@ -0,0 +1,399 @@ +""" +Copyright 2024 The LGcns AI Engineering Team +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Adapted from llama2.py +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class ExaoneGatedMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 500000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 4096, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.rotary_dim = int( + self.head_dim * getattr(config, "partial_rotary_factor", 1) + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneDecoderLayer(nn.Module): + def __init__( + self, + config, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, 
"max_position_embeddings", 4096) + self.self_attn = ExaoneAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + rms_norm_eps = config.layer_norm_epsilon + self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ExaoneModel(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.h = nn.ModuleList( + [ + ExaoneDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.h.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + rms_norm_eps = config.layer_norm_epsilon + self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.wte(input_ids) + else: + hidden_states = input_embeds + residual = None + for i in range(len(self.h)): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class ExaoneForCausalLM(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + efficient_weight_load=False, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.transformer = ExaoneModel(config, quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> LogitsProcessorOutput: + hidden_states = self.transformer( + input_ids, positions, input_metadata, input_embeds + ) + logits_output = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + sample_output = 
self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output + + def get_module_name(self, name): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id, num_shard) + ("qkv_proj", "q_proj", "q", 3), + ("qkv_proj", "k_proj", "k", 3), + ("qkv_proj", "v_proj", "v", 3), + ("gate_up_proj", "c_fc_0", 0, 2), + ("gate_up_proj", "c_fc_1", 1, 2), + ] + for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "c_fc_0", 0), + ("gate_up_proj", "c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + + def load_weights_per_param(name, loaded_weight): + if "rotary_emb.inv_freq" in name or "projector" in name: + return + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + return + if name.startswith("model.vision_tower") and name not in params_dict: + return + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + if name is None or loaded_weight is None: + for name, loaded_weight in weights: + name = name.replace("attn.attention", "self_attn") + load_weights_per_param(name, loaded_weight) + else: + name = name.replace("attn.attention", "self_attn") + load_weights_per_param(name, loaded_weight) + + +EntryClass = ExaoneForCausalLM diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index ce39731156..5a6e5df37f 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -23,8 +23,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import GeluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -35,8 +33,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -60,7 +61,7 @@ def __init__( bias=False, quant_config=quant_config, ) - self.act_fn = GeluAndMul() + self.act_fn = GeluAndMul("none") def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -287,6 +288,7 @@ def __init__( self.quant_config = quant_config self.model = GemmaModel(config, quant_config=quant_config) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -297,9 +299,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return (sample_output, logits_output) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index db87624d2d..77ebd8564c 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -22,12 +22,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size - -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.activation import GeluAndMul - -# from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -39,55 +33,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import 
GeluAndMul +from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata -class GemmaRMSNorm(CustomOp): - """RMS normalization for Gemma. - - Two differences from the above RMSNorm: - 1. x * (1 + w) instead of x * w. - 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. - """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - ) -> None: - super().__init__() - self.weight = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward_native( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype - if residual is not None: - x = x + residual - residual = x - - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - x = x * (1.0 + self.weight.float()) - x = x.to(orig_dtype) - return x if residual is None else (x, residual) - - def forward_cuda( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. - return self.forward_native(x, residual) +# Aligned with HF's implementation, using sliding window inclusive with the last token +# SGLang assumes exclusive +def get_attention_sliding_window_size(config): + return config.sliding_window - 1 # FIXME: temporary solution, remove after next vllm release @@ -129,7 +86,7 @@ def __init__( "function. Please set `hidden_act` and `hidden_activation` to " "`gelu_pytorch_tanh`." ) - self.act_fn = GeluAndMul(approximate="tanh") + self.act_fn = GeluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) @@ -200,17 +157,18 @@ def __init__( dtype=torch.get_default_dtype(), ) - # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every - # odd layer, vLLM currently ignores it and uses global attention for - # all layers. - use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None - del use_sliding_window # Unused. 
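For reference, a minimal standalone sketch of the sliding-window convention introduced above, assuming (as the comment states) that HF counts the window inclusive of the current token while SGLang's attention backend expects only the number of extra past tokens, and that only even-indexed Gemma-2 layers use local attention; the helper name is illustrative, not part of the module.

def sliding_window_for_layer(layer_idx: int, sliding_window):
    # Even-indexed layers use local (sliding-window) attention; odd layers stay global.
    if sliding_window is None or layer_idx % 2 != 0:
        return None
    # HF's `sliding_window` includes the current token; the kernel wants it exclusive.
    return sliding_window - 1

assert sliding_window_for_layer(0, 4096) == 4095   # local layer: 4095 past tokens
assert sliding_window_for_layer(1, 4096) is None   # global layer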
+ use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window") self.attn = RadixAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_idx, + sliding_window_size=( + get_attention_sliding_window_size(config) + if use_sliding_window + else None + ), logit_cap=self.config.attn_logit_softcapping, ) @@ -389,6 +347,7 @@ def __init__( self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -399,9 +358,14 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output + + def get_attention_sliding_window_size(self): + return get_attention_sliding_window_size(self.config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 9a9e2aec3a..dc828f0142 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -23,7 +23,6 @@ from transformers import GPTBigCodeConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -33,8 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -261,6 +262,7 @@ def __init__( if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -270,9 +272,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 38297b7d6e..3c2a2c65ea 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,29 +16,24 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" +import warnings from typing import Iterable, List, Optional, Tuple -import numpy as np import torch import torch.nn.functional as F -import tqdm from 
torch import nn from transformers import PretrainedConfig -from vllm import _custom_ops as ops from vllm.config import CacheConfig from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -46,140 +41,14 @@ ) from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import print_warning_once -from sglang.srt.layers.fused_moe import fused_moe +from sglang.srt.layers.fused_moe import FusedMoE +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata -use_fused = True - - -class Grok1MLP(nn.Module): - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config - ) - self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - - self.act_fn = nn.GELU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class Grok1MoEUnfused(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}." 
- ) - # Split experts equally between ranks - self.expert_indicies = np.array_split( - range(self.num_total_experts), self.tp_size - )[self.rank].tolist() - if not self.expert_indicies: - raise ValueError(f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList( - [ - ( - Grok1MLP( - self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config, - ) - if idx in self.expert_indicies - else None - ) - for idx in range(self.num_total_experts) - ] - ) - self.gate = ReplicatedLinear( - config.hidden_size, self.num_total_experts, bias=False, quant_config=None - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - router_logits, _ = self.gate(hidden_states) - router_logits = 30 * F.tanh(router_logits / 30) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - routing_weights = routing_weights.to(hidden_states.dtype) - hidden_dim = hidden_states.shape[1] - - final_hidden_states = torch.zeros( - (hidden_states.shape[0], hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_total_experts - ).permute(2, 1, 0) - - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - if top_x.shape[0] == 0: - continue - - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) - * routing_weights[top_x_list, idx_list, None] - ) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states) - class Grok1MoE(nn.Module): """A tensor-parallel MoE implementation for Grok1 that shards each expert @@ -197,221 +66,42 @@ def __init__( hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, ): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - self.quant_config = quant_config - - # FIXME(pcmoritz): Make this more general to support different - # quantization schemes - self.use_fp8 = isinstance(quant_config, Fp8Config) - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Gate always runs at half / full precision for now. 
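For orientation, a minimal unfused sketch of the routing math that Grok1MoE delegates to FusedMoE: soft-capped router logits, top-k selection without renormalization, and a weighted sum of expert outputs. The function and argument names here are illustrative assumptions, not the FusedMoE interface.

import torch
import torch.nn.functional as F

def reference_topk_moe(hidden, gate_weight, experts, top_k):
    # hidden: (tokens, dim); gate_weight: (num_experts, dim); experts: list of callables.
    logits = hidden @ gate_weight.t()
    logits = 30.0 * torch.tanh(logits / 30.0)            # soft-cap router logits to (-30, 30)
    probs = F.softmax(logits, dim=-1, dtype=torch.float)
    weights, chosen = torch.topk(probs, top_k, dim=-1)   # renormalize=False: keep raw probabilities
    out = torch.zeros_like(hidden)
    for t in range(hidden.shape[0]):
        for w, e in zip(weights[t], chosen[t]):
            out[t] += w.to(hidden.dtype) * experts[int(e)](hidden[t])
    return out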
self.gate = ReplicatedLinear( - self.hidden_size, - self.num_total_experts, + hidden_size, + num_experts, bias=False, - params_dtype=self.params_dtype, + params_dtype=params_dtype, quant_config=None, ) - if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - self.w13_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype, - ) - ) - self.w2_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype, - ) - ) - - set_weight_attrs( - self.w13_weight, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_weight, - { - "weight_loader": self.weight_loader, - }, + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, ) - # Used for fp8. - self.w13_scale = None - self.w2_scale = None - self.a13_scale = None - self.a2_scale = None - - if self.use_fp8: - # WEIGHT_SCALE (for fp8) - self.w13_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.w2_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs( - self.w13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - # ACT_SCALE (for fp8) - if quant_config.activation_scheme == "static": - if not quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - self.a13_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.a2_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - set_weight_attrs( - self.a13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.a2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - pre_sharded: bool, - ): - param_data = param.data - shard_size = self.intermediate_size - if pre_sharded: - # The weight is already sharded. Readl the full shard - shard = slice(None) - else: - tp_rank = get_tensor_model_parallel_rank() - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - if "act_scale" in weight_name or "weight_scale" in weight_name: - param_data[expert_id] = loaded_weight - - def process_weights_after_loading(self): - # Fp8 is the only case where we need to process after loading. - if not self.use_fp8: - return - - # If checkpoint is fp16, quantize here. 
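The block removed here used to quantize each expert's weights to fp8 inline; with the quant_config handed to FusedMoE, that bookkeeping leaves the model file. As a rough, hypothetical illustration of per-tensor fp8 scaling (not the vLLM kernel's implementation):

import torch

w = torch.randn(128, 256)                                      # one expert's weight in high precision
scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max   # map the largest magnitude to the e4m3 max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)                    # stored weight
w_dequant = w_fp8.to(torch.float32) * scale                    # approximate reconstruction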
- if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - self.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn) - for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :] - ) - w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :] - ) - self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - - # If checkpoint is fp8 + static, cleanup act_scales. - # Since state_dict has an act_scale per expert but our kernels - # are passed one act_scale shared across all experts. - elif self.quant_config.activation_scheme == "static": - if self.a13_scale is None or self.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - - if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale): - print_warning_once( - "Found act_scales that are not equal for fp8 MoE layer. " - "Using the maximum across experts for each layer. " - ) - - self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe( - hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=False, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) class Grok1Attention(nn.Module): @@ -478,6 +168,7 @@ def __init__( layer_id=layer_id, logit_cap=logit_cap, ) + # TODO(lianmin): load logit cap from config def forward( self, @@ -502,7 +193,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( hidden_size=self.hidden_size, @@ -513,18 +204,13 @@ def __init__( rope_theta=rope_theta, quant_config=quant_config, ) - if use_fused: - self.block_sparse_moe = Grok1MoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - ) - else: - self.block_sparse_moe = Grok1MoEUnfused( - config=config, quant_config=quant_config - ) + self.block_sparse_moe = Grok1MoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -536,6 +222,7 @@ def forward( hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: + # Self Attention hidden_states = ( self.post_attn_norm( self.self_attn( @@ -547,11 +234,11 @@ def forward( + hidden_states ) + # Fully Connected hidden_states = ( self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states ) - return hidden_states @@ -587,19 +274,18 @@ def forward( ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) + hidden_states.mul_(self.config.embedding_multiplier_scale) else: hidden_states = input_embeds - hidden_states.mul_(self.config.embedding_multiplier_scale) for i in range(len(self.layers)): hidden_states = self.layers[i](positions, hidden_states, input_metadata) - hidden_states = self.norm(hidden_states) hidden_states.mul_(self.config.output_multiplier_scale) return hidden_states -class Grok1ModelForCausalLM(nn.Module): +class Grok1ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, @@ -612,11 +298,15 @@ def __init__( self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) - @torch.no_grad() + self.use_presharded_weights = True + + warnings.filterwarnings("ignore", category=FutureWarning) + def forward( self, input_ids: torch.Tensor, @@ -625,9 +315,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -637,50 +329,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - if use_fused: - expert_params_mapping = ( - [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ( - "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ( - "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ( - "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - ) - else: - expert_params_mapping = [] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = 
FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) params_dict = dict(self.named_parameters()) - if get_tensor_model_parallel_rank() == 0: - weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4)) for name, loaded_weight in weights: - # print(get_tensor_model_parallel_rank(), name) if "rotary_emb.inv_freq" in name: continue @@ -691,29 +350,43 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) + + if self.use_presharded_weights: + extra_kwargs = { + "use_presharded_weights": self.use_presharded_weights + } + else: + extra_kwargs = {} + param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, weight_name, + shard_id=shard_id, expert_id=expert_id, - pre_sharded=get_tensor_model_parallel_world_size() > 1, + **extra_kwargs, ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if name is None: + continue + param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -721,11 +394,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) - - old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") @@ -751,4 +419,10 @@ def _prepare_presharded_weights( return hf_folder, hf_weights_files, use_safetensors -EntryClass = Grok1ModelForCausalLM +class Grok1ModelForCausalLM(Grok1ForCausalLM): + """An alias for backward-compatbility.""" + + pass + + +EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM] diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index f2947e991b..c0e4d19e12 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,6 +40,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -262,6 +263,7 @@ def __init__( self.model = InternLM2Model(config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -272,9 +274,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.output.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 
stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 9de8d33c5c..22751d9b67 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -39,8 +39,9 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -302,6 +303,7 @@ def __init__( self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -310,11 +312,13 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, - ) -> LogitProcessorOutput: + ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def get_module_name(self, name): stacked_params_mapping = [ @@ -357,6 +361,9 @@ def load_weights_per_param(name, loaded_weight): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. return + if name.startswith("model.vision_tower") and name not in params_dict: + return + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -364,8 +371,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -374,8 +379,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. 
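As context for these load_weights hunks, a self-contained sketch of how a stacked-parameter mapping folds separate q/k/v (and gate/up) checkpoint tensors into fused parameters, with shard_id telling the weight loader which slice to fill; the checkpoint keys below are hypothetical examples, not taken from a real model.

stacked_params_mapping = [
    # (param_name, checkpoint_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    for param_name, ckpt_name, shard_id in stacked_params_mapping:
        if ckpt_name in name:
            return name.replace(ckpt_name, param_name), shard_id
    return name, None

assert remap("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight", "k")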
if name.endswith(".bias") and name not in params_dict: return - if name.startswith("model.vision_tower") and name not in params_dict: - return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 02224971d6..03ab5e802c 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama2 import LlamaModel @@ -65,7 +65,7 @@ def forward( (input_metadata.batch_size, self.config.classification_out_size) ).to(input_ids.device) - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, @@ -103,8 +103,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -113,8 +111,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index e8e6780472..e4e9174f14 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -29,7 +29,11 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, + get_embedding: bool = True, ) -> EmbeddingPoolerOutput: + assert ( + get_embedding + ), "LlamaEmbeddingModel / MistralModel is only used for embedding" hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.pooler(hidden_states, input_metadata) @@ -53,6 +57,9 @@ def load_weights_per_param(name, loaded_weight): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. return + if name.startswith("model.vision_tower") and name not in params_dict: + return + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -60,8 +67,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -70,8 +75,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. 
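Ahead of the LLaVA pad_input_ids changes that follow, a small worked example of how the pad ids are built: the per-image pad_value list is repeated enough times to cover the image feature length and then truncated. The numbers are made up for illustration.

pad_value = [11, 22, 33]               # hypothetical per-image pad token ids
new_image_feature_len = 7
pad_ids = pad_value * ((new_image_feature_len + len(pad_value)) // len(pad_value))
assert len(pad_ids) >= new_image_feature_len
assert pad_ids[:new_image_feature_len] == [11, 22, 33, 11, 22, 33, 11]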
if name.endswith(".bias") and name not in params_dict: return - if name.startswith("model.vision_tower") and name not in params_dict: - return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index a885a6e595..7dcf5348b0 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -15,6 +15,8 @@ """Inference-only LLaVa model compatible with HuggingFace weights.""" +import math +import re from typing import Iterable, List, Optional, Tuple import numpy as np @@ -26,6 +28,7 @@ LlavaConfig, MistralConfig, Qwen2Config, + SiglipVisionModel, ) from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.config import CacheConfig @@ -43,54 +46,68 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM -class LlavaLlamaForCausalLM(nn.Module): - def __init__( +class LlavaBaseForCausalLM(nn.Module): + def pad_input_ids( self, - config: LlavaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.vision_tower = None - self.config.vision_config.hidden_size = config.mm_hidden_size - self.config.text_config.hidden_size = config.hidden_size - self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCausalLM(config, quant_config=quant_config) - if "unpad" in getattr(config, "mm_patch_merge_type", ""): - self.language_model.model.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size, dtype=torch.float16) - ) + input_ids: List[int], + pad_value: List[int], + pixel_values: List, + image_sizes: List[List[int]], + ): + # hardcode for spatial_unpad + anyres + image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad" + offset_list = [] + for image_s in image_sizes: + if len(image_sizes) > 16: + # 2x2 pooling with stride 2 + new_image_feature_len = ( + math.ceil(self.image_size / self.patch_size / 2) ** 2 + ) + else: + new_image_feature_len = self.image_feature_len # multiimage - def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): - new_image_feature_len = self.image_feature_len - # now only support spatial_unpad + anyres - if self.mm_patch_merge_type.startswith("spatial"): height = width = self.num_patches_per_side - if pt_shape[0] > 1: - if self.image_aspect_ratio == "anyres": - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, - self.image_grid_pinpoints, - self.vision_tower.config.image_size, + if "anyres" in image_aspect_ratio: + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_s, + self.image_grid_pinpoints, + self.vision_tower.config.image_size, + ) + h = num_patch_height * height + w = num_patch_width * width + new_h, new_w = unpad_image_shape(h, w, image_s) + + if "anyres_max" in self.config.image_aspect_ratio: + matched_anyres_max_num_patches = re.match( + r"anyres_max_(\d+)", self.config.image_aspect_ratio ) - if "unpad" in self.mm_patch_merge_type: - h = num_patch_height * height - w = num_patch_width * width - new_h, new_w = unpad_image_shape(h, w, image_size) - new_image_feature_len += new_h * (new_w + 1) - - pad_ids = pad_value * ( - (new_image_feature_len + len(pad_value)) // len(pad_value) - ) - offset = input_ids.index(self.config.image_token_index) - # old_len + pad_len - 1, because we need to remove image_token_id - 
new_input_ids = ( - input_ids[:offset] - + pad_ids[:new_image_feature_len] - + input_ids[offset + 1 :] - ) - return new_input_ids, offset + if matched_anyres_max_num_patches: + max_num_patches = int(matched_anyres_max_num_patches.group(1)) + # times = math.sqrt(h * w / (max_num_patches * unit**2)) + times = math.sqrt( + new_h * new_w / (max_num_patches * self.image_feature_len) + ) + if times > 1.1: + new_h = int(new_h // times) + new_w = int(new_w // times) + new_image_feature_len += new_h * (new_w + 1) + + pad_ids = pad_value * ( + (new_image_feature_len + len(pad_value)) // len(pad_value) + ) + # print("calculated new_image_feature_len: ", new_image_feature_len) + try: + offset = input_ids.index(self.config.image_token_index) + except ValueError: + offset = 0 + # old_len + pad_len - 1, because we need to remove image_token_id + input_ids = ( + input_ids[:offset] + + pad_ids[:new_image_feature_len] + + input_ids[offset + 1 :] + ) + offset_list.append(offset) + return input_ids, offset_list def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -122,18 +139,15 @@ def forward( if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size - # Embed text input + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Embed vision input - need_vision = ( - (positions[input_metadata.extend_start_loc] < self.image_feature_len) - .cpu() - .numpy() + # Whether the requests need vision inputs + max_image_offset = np.array( + [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] ) - # FIXME: We need to substract the length of the system prompt - has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) - need_vision = need_vision & has_pixel + start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + need_vision = start_positions <= max_image_offset if need_vision.any(): pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] @@ -163,27 +177,73 @@ def forward( if self.mm_patch_merge_type.startswith("spatial"): new_image_features = [] + height = width = self.num_patches_per_side for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: + if len(image_sizes[image_idx]) == 1: + image_aspect_ratio = ( + self.config.image_aspect_ratio + ) # single image + else: + image_aspect_ratio = "pad" # multi image + # image_aspect_ratio = ( + # "anyres" if len(image_sizes[image_idx]) == 1 else "pad" + # ) + if ( + image_feature.shape[0] > 1 + and "anyres" in image_aspect_ratio + ): base_image_feature = image_feature[0] image_feature = image_feature[1:] - height = width = self.num_patches_per_side assert height * width == base_image_feature.shape[0] - if self.image_aspect_ratio == "anyres": - ( - num_patch_width, - num_patch_height, - ) = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.image_grid_pinpoints, - self.vision_tower.config.image_size, + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match( + r"anyres_max_(\d+)", image_aspect_ratio ) + if matched_anyres_max_num_patches: + max_num_patches = int( + matched_anyres_max_num_patches.group(1) + ) + + if ( + image_aspect_ratio == "anyres" + or "anyres_max" in image_aspect_ratio + ): + vision_tower_image_size = self.image_size + try: + num_patch_width, num_patch_height = ( + get_anyres_image_grid_shape( + image_sizes[image_idx][0], + self.config.image_grid_pinpoints, + 
vision_tower_image_size, + ) + ) + except Exception as e: + print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 image_feature = image_feature.view( num_patch_height, num_patch_width, height, width, -1 ) else: - raise NotImplementedError() + image_feature = image_feature.view( + 2, 2, height, width, -1 + ) + + # ( + # num_patch_width, + # num_patch_height, + # ) = get_anyres_image_grid_shape( + # image_sizes[image_idx][0], + # self.image_grid_pinpoints, + # self.vision_tower.config.image_size, + # ) + + # image_feature = image_feature.view( + # num_patch_height, num_patch_width, height, width, -1 + # ) + if "unpad" in self.mm_patch_merge_type: + unit = image_feature.shape[2] image_feature = image_feature.permute( 4, 0, 2, 1, 3 ).contiguous() @@ -191,8 +251,23 @@ def forward( 2, 3 ) image_feature = unpad_image( - image_feature, image_sizes[image_idx] + image_feature, image_sizes[image_idx][0] ) + if ( + "anyres_max" in image_aspect_ratio + and matched_anyres_max_num_patches + ): + c, h, w = image_feature.shape + times = math.sqrt( + h * w / (max_num_patches * unit**2) + ) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate( + image_feature, + [int(h // times), int(w // times)], + mode="bilinear", + )[0] image_feature = torch.cat( ( image_feature, @@ -213,43 +288,63 @@ def forward( image_feature = torch.cat( (base_image_feature, image_feature), dim=0 ) + image_feature = image_feature.unsqueeze(0) else: - image_feature = image_feature[0] - if "unpad" in self.mm_patch_merge_type: - image_feature = torch.cat( - ( - image_feature, - self.language_model.model.image_newline[None], - ), - dim=0, + if image_feature.shape[0] > 16: # video + # 2x2 pooling + num_of_frames = image_feature.shape[0] + image_feature = image_feature.view( + num_of_frames, height, width, -1 + ) + image_feature = image_feature.permute( + 0, 3, 1, 2 + ).contiguous() # N, C, H, W + height, weight = image_feature.shape[2:] + scaled_shape = [ + math.ceil(height / 2), + math.ceil(weight / 2), + ] + image_feature = nn.functional.interpolate( + image_feature, size=scaled_shape, mode="bilinear" ) + image_feature = ( + image_feature.flatten(2) + .transpose(1, 2) + .contiguous() + ) # N, C, H*W + new_image_features.append(image_feature) image_features = new_image_features + # Fill in the placeholder for the image extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: continue start_idx = extend_start_loc_cpu[i] - pad_len, pad_dim = image_features[pt].shape # 576, 4096 - dim = input_embeds.shape[1] - assert ( - pad_dim == dim - ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) - # Fill in the placeholder for the image - try: - input_embeds[ - start_idx - + image_offsets[i] : start_idx - + image_offsets[i] - + pad_len - ] = image_features[pt] - except RuntimeError as e: - print(f"RuntimeError in llava image encoding: {e}") - print(input_embeds.shape) - print(start_idx, image_offsets[i]) + prefix_len = prefix_lens_cpu[i] + + # Multiple images + for j, image_offset in enumerate(image_offsets[i]): + if image_offset < prefix_len: + continue + + tmp_image_feature = image_features[pt][j] + pad_len = tmp_image_feature.shape[0] + + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len + try: + input_embeds[left_idx:right_idx] = tmp_image_feature + except RuntimeError as e: + 
print(f"RuntimeError in image encoding: {e}") + print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}") + print( + f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" + ) pt += 1 return self.language_model( @@ -259,12 +354,20 @@ def forward( return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # load clip vision model by cfg['mm_vision_tower']: - # huggingface_name or path_of_clip_relative_to_llava_model_dir + # Load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + # We put the initialization here instead of __init__ to allow it being reused by other subclasses. vision_path = self.config.mm_vision_tower - self.vision_tower = CLIPVisionModel.from_pretrained( - vision_path, torch_dtype=torch.float16 - ).cuda() + if "clip" in vision_path: + self.vision_tower = CLIPVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + elif "siglip" in vision_path: + self.vision_tower = SiglipVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + # Siglip needs all feature tokens + self.config.mm_vision_select_feature = "full" self.vision_tower.eval() self.vision_feature_layer = self.config.mm_vision_select_layer @@ -276,8 +379,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None) - self.image_feature_len = int((self.image_size / self.patch_size) ** 2) - if self.vision_feature_select_strategy == "patch": + self.image_feature_len = int((self.image_size // self.patch_size) ** 2) + if ( + self.vision_feature_select_strategy == "patch" + or self.vision_feature_select_strategy == "full" + ): pass elif self.vision_feature_select_strategy == "cls_patch": self.image_feature_len += 1 @@ -305,21 +411,41 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - @property def num_patches_per_side(self): return self.image_size // self.patch_size -class LlavaQwenForCausalLM(LlavaLlamaForCausalLM): +class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.vision_tower = None + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = LlamaForCausalLM(config, quant_config=quant_config) + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + +class LlavaQwenForCausalLM(LlavaBaseForCausalLM): def __init__( self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, ) -> None: - super().__init__(config, quant_config=quant_config, cache_config=cache_config) + super().__init__() + self.config = config self.vision_tower = None if getattr(self.config, "vision_config", None) is None: @@ -345,14 +471,15 @@ def __init__( ) -class LlavaMistralForCausalLM(LlavaLlamaForCausalLM): +class 
LlavaMistralForCausalLM(LlavaBaseForCausalLM): def __init__( self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, ) -> None: - super().__init__(config, quant_config=quant_config, cache_config=cache_config) + super().__init__() + self.config = config self.vision_tower = None if getattr(self.config, "vision_config", None) is None: @@ -378,36 +505,4 @@ def __init__( ) -first_call = True - - -def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - - # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. - global first_call - if first_call: - self.patch_embedding.cpu().float() - first_call = False - pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") - patch_embeds = self.patch_embedding(pixel_values).cuda().half() - - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -def monkey_path_clip_vision_embed_forward(): - import transformers - - setattr( - transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, - "forward", - clip_vision_embed_forward, - ) - - EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 8b81251d69..44e400ff6a 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,11 +26,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.mm_utils import ( - get_anyres_image_grid_shape, - unpad_image, - unpad_image_shape, -) from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM @@ -59,23 +54,14 @@ def __init__( torch.empty(config.text_config.hidden_size, dtype=torch.float16) ) - def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + def pad_input_ids( + self, + input_ids: List[int], + pad_value: List[int], + pixel_values: List, + image_sizes: List[List[int]], + ): new_image_feature_len = self.image_feature_len - # now only support spatial_unpad + anyres - # if self.mm_patch_merge_type.startswith("spatial"): - # height = width = self.num_patches_per_side - # if pt_shape[0] > 1: - # if self.image_aspect_ratio == "anyres": - # num_patch_width, num_patch_height = get_anyres_image_grid_shape( - # image_size, - # self.image_grid_pinpoints, - # self.vision_tower.config.image_size, - # ) - # if "unpad" in self.mm_patch_merge_type: - # h = num_patch_height * height - # w = num_patch_width * width - # new_h, new_w = unpad_image_shape(h, w, image_size) - # new_image_feature_len += new_h * (new_w + 1) pad_ids = pad_value * ( (new_image_feature_len + len(pad_value)) // len(pad_value) @@ -87,7 +73,7 @@ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :] ) - return new_input_ids, offset + return new_input_ids, [offset] def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -133,22 +119,18 @@ def forward( if input_metadata.forward_mode 
== ForwardMode.EXTEND: bs = input_metadata.batch_size - # Embed text input + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Embed vision input - need_vision = ( - (positions[input_metadata.extend_start_loc] < self.image_feature_len) - .cpu() - .numpy() + # Whether the requests need vision inputs + max_image_offset = np.array( + [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] ) - # FIXME: We need to substract the length of the system prompt - has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) - need_vision = need_vision & has_pixel + start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + need_vision = start_positions <= max_image_offset if need_vision.any(): pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] - image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] ########## Encode Image ######## @@ -183,31 +165,36 @@ def forward( new_image_features.append(image_feature.flatten(0, 1)) image_features = new_image_features + # Fill in the placeholder for the image extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: continue start_idx = extend_start_loc_cpu[i] - pad_len, pad_dim = image_features[pt].shape # 576, 4096 - dim = input_embeds.shape[1] - assert ( - pad_dim == dim - ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) - # Fill in the placeholder for the image - try: - input_embeds[ - start_idx - + image_offsets[i] : start_idx - + image_offsets[i] - + pad_len - ] = image_features[pt] - except RuntimeError as e: - print(f"RuntimeError in llava image encoding: {e}") - print(input_embeds.shape) - print(start_idx, image_offsets[i]) - pt += 1 + prefix_len = prefix_lens_cpu[i] + + # Multiple images + for image_offset in image_offsets[i]: + if image_offset < prefix_len: + continue + + tmp_image_feature = image_features[pt] + pad_len = tmp_image_feature.shape[0] + + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len + try: + input_embeds[left_idx:right_idx] = tmp_image_feature + except RuntimeError as e: + print(f"RuntimeError in image encoding: {e}") + print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}") + print( + f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" + ) + pt += 1 return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds @@ -216,8 +203,9 @@ def forward( return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # load clip vision model by cfg['mm_vision_tower']: - # huggingface_name or path_of_clip_relative_to_llava_model_dir + # Load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + # We put the initialization here instead of __init__ to allow it to be reused by other subclasses.
vision_path = self.config.mm_vision_tower self.vision_tower = CLIPVisionModel.from_pretrained( vision_path, torch_dtype=torch.float16 @@ -271,43 +259,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - @property def num_patches_per_side(self): return self.image_size // self.patch_size -first_call = True - - -def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - - # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. - global first_call - if first_call: - self.patch_embedding.cpu().float() - first_call = False - pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") - patch_embeds = self.patch_embedding(pixel_values).cuda().half() - - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -def monkey_path_clip_vision_embed_forward(): - import transformers - - setattr( - transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, - "forward", - clip_vision_embed_forward, - ) - - EntryClass = LlavaVidForCausalLM diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index bf572855e6..0028ae67a8 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -22,8 +22,6 @@ from torch import nn from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -37,8 +35,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -297,6 +298,7 @@ def __init__( self.scale_width = self.config.hidden_size / self.config.dim_model_base self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -314,9 +316,11 @@ def forward( lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, lm_head_weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 63053ac50b..ca38cb03ba 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -18,38 +18,30 @@ """Inference-only Mixtral model.""" from typing import Iterable, Optional, Tuple -import numpy as np import torch -import torch.nn.functional as F from torch import nn from transformers import MixtralConfig -from vllm import _custom_ops as ops from vllm.config 
import CacheConfig -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import print_warning_once +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -69,216 +61,44 @@ def __init__( hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", ): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - self.quant_config = quant_config - - # FIXME(pcmoritz): Make this more general to support different - # quantization schemes - self.use_fp8 = isinstance(quant_config, Fp8Config) - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear( - self.hidden_size, - self.num_total_experts, + hidden_size, + num_experts, bias=False, - params_dtype=self.params_dtype, + params_dtype=params_dtype, quant_config=None, + prefix=f"{prefix}.gate", ) - if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - self.w13_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype, - ) - ) - self.w2_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype, - ) - ) - - set_weight_attrs( - self.w13_weight, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_weight, - { - "weight_loader": self.weight_loader, - }, + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", ) - # Used for fp8. 
- self.w13_scale = None - self.w2_scale = None - self.a13_scale = None - self.a2_scale = None - - if self.use_fp8: - # WEIGHT_SCALE (for fp8) - self.w13_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.w2_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs( - self.w13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - # ACT_SCALE (for fp8) - if quant_config.activation_scheme == "static": - if not quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - self.a13_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.a2_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - set_weight_attrs( - self.a13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.a2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - ): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - if "act_scale" in weight_name or "weight_scale" in weight_name: - param_data[expert_id] = loaded_weight - - def process_weights_after_loading(self): - # Fp8 is the only case where we need to process after loading. - if not self.use_fp8: - return - - # If checkpoint is fp16, quantize here. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - self.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn) - for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :] - ) - w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :] - ) - self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - - # If checkpoint is fp8 + static, cleanup act_scales. - # Since state_dict has an act_scale per expert but our kernels - # are passed one act_scale shared across all experts. - elif self.quant_config.activation_scheme == "static": - if self.a13_scale is None or self.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - - if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale): - print_warning_once( - "Found act_scales that are not equal for fp8 MoE layer. 
" - "Using the maximum across experts for each layer. " - ) - - self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe( - hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=True, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) class MixtralAttention(nn.Module): @@ -291,7 +111,7 @@ def __init__( max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -314,7 +134,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window self.qkv_proj = QKVParallelLinear( hidden_size, @@ -323,12 +142,14 @@ def __init__( self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -365,6 +186,7 @@ def __init__( config: MixtralConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -377,8 +199,8 @@ def __init__( num_kv_heads=config.num_key_value_heads, layer_id=layer_id, rope_theta=rope_theta, - sliding_window=config.sliding_window, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, @@ -386,6 +208,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -422,6 +245,7 @@ def __init__( self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -431,10 +255,11 @@ def __init__( config.vocab_size, config.hidden_size, ) - # config.num_hidden_layers=16 self.layers = nn.ModuleList( [ - MixtralDecoderLayer(config, i, quant_config=quant_config) + MixtralDecoderLayer( + config, i, quant_config=quant_config, prefix=f"{prefix}.layers" + ) for i in range(config.num_hidden_layers) ] ) @@ -462,6 +287,7 @@ def forward( class MixtralForCausalLM(nn.Module): + def __init__( self, config: MixtralConfig, @@ -471,11 +297,11 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config 
- self.model = MixtralModel(config, quant_config=quant_config) + self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() - @torch.no_grad() def forward( self, input_ids: torch.Tensor, @@ -484,9 +310,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -496,40 +324,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - expert_params_mapping = ( - [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ( - "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ( - "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ( - "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, ) params_dict = dict(self.named_parameters()) @@ -544,25 +345,35 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader( - param, loaded_weight, weight_name, expert_id=expert_id + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, ) break else: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if name is None: + continue + param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -570,9 +381,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) - - EntryClass = MixtralForCausalLM diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 07caf38334..97ac09ee62 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -29,7 +29,6 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, ReplicatedLinear, @@ -43,8 +42,10 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -160,7 +161,6 @@ def __init__( max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -183,7 +183,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window self.qkv_proj = QKVParallelLinear( hidden_size, @@ -246,7 +245,6 @@ def __init__( num_kv_heads=config.num_key_value_heads, layer_id=layer_id, rope_theta=rope_theta, - sliding_window=config.sliding_window, quant_config=quant_config, ) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) @@ -336,6 +334,7 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -346,9 +345,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index ffc512b1ca..4958a81298 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -22,8 +22,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -37,8 +35,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from 
sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -251,6 +252,7 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -260,10 +262,11 @@ def forward( input_metadata: InputMetadata, ): hidden_states = self.transformer(input_ids, positions, input_metadata) - next_tokens = self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - return next_tokens + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index dec962bf0a..6bb5c0b906 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -22,8 +22,6 @@ from torch import nn from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -37,8 +35,12 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None @@ -275,6 +277,8 @@ def __init__( self.model = Qwen2Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() def forward( @@ -283,11 +287,17 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, + get_embedding: bool = False, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata - ) + if not get_embedding: + logits_output = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output + else: + return self.pooler(hidden_states, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -306,6 +316,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -313,8 +326,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -323,8 +334,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index f96f7e0e48..67b5a6ce66 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -28,27 +28,26 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -366,6 +365,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -376,20 +376,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - - def compute_logits( - self, - input_ids: torch.Tensor, - hidden_states: torch.Tensor, - input_metadata: InputMetadata, - ) -> torch.Tensor: - logits = self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata - ) - return logits + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -401,24 +392,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - 
expert_params_mapping = [ - # These are the weights for the experts - # (param_name, weight_name, expert_id, shard_id) - ( - ( - "experts.w13_weight" - if weight_name in ["gate_proj", "up_proj"] - else "experts.w2_weight" - ), - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - shard_id, - ) - for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate( - ["gate_proj", "down_proj", "up_proj"] - ) - ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -458,7 +437,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, ) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index aeaa46ab12..a3102baabd 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -24,7 +24,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -38,8 +37,10 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -249,6 +250,7 @@ def __init__( self.model = StableLMEpochModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -259,9 +261,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 11d4cda1c0..0f86206d82 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -24,10 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.llava import ( - LlavaLlamaForCausalLM, - monkey_path_clip_vision_embed_forward, -) +from sglang.srt.models.llava import LlavaLlamaForCausalLM class YiVLForCausalLM(LlavaLlamaForCausalLM): @@ -50,7 +47,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config._name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder, - ).cuda() + ).to("cuda") self.vision_tower.eval() @@ -94,8 +91,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load 
language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - class YiVLMultiModalProjector(nn.Module): def __init__(self, config: LlavaConfig): diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 8998cf39de..4feb632b0b 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -17,6 +17,7 @@ import asyncio import json +import logging import os import time import uuid @@ -64,6 +65,8 @@ UsageInfo, ) +logger = logging.getLogger(__name__) + chat_template_name = None @@ -117,37 +120,48 @@ def create_streaming_error_response( return json_str -def load_chat_template_for_openai_api(chat_template_arg): +def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg): global chat_template_name - print(f"Use chat template: {chat_template_arg}") + logger.info(f"Use chat template: {chat_template_arg}") if not chat_template_exists(chat_template_arg): if not os.path.exists(chat_template_arg): raise RuntimeError( f"Chat template {chat_template_arg} is not a built-in template name " "or a valid chat template file path." ) - with open(chat_template_arg, "r") as filep: - template = json.load(filep) - try: - sep_style = SeparatorStyle[template["sep_style"]] - except KeyError: - raise ValueError( - f"Unknown separator style: {template['sep_style']}" - ) from None - register_conv_template( - Conversation( - name=template["name"], - system_template=template["system"] + "\n{system_message}", - system_message=template.get("system_message", ""), - roles=(template["user"], template["assistant"]), - sep_style=sep_style, - sep=template.get("sep", "\n"), - stop_str=template["stop_str"], - ), - override=True, + if chat_template_arg.endswith(".jinja"): + with open(chat_template_arg, "r") as f: + chat_template = "".join(f.readlines()).strip("\n") + tokenizer_manager.tokenizer.chat_template = chat_template.replace( + "\\n", "\n" ) - chat_template_name = template["name"] + chat_template_name = None + else: + assert chat_template_arg.endswith( + ".json" + ), "unrecognized format of chat template file" + with open(chat_template_arg, "r") as filep: + template = json.load(filep) + try: + sep_style = SeparatorStyle[template["sep_style"]] + except KeyError: + raise ValueError( + f"Unknown separator style: {template['sep_style']}" + ) from None + register_conv_template( + Conversation( + name=template["name"], + system_template=template["system"] + "\n{system_message}", + system_message=template.get("system_message", ""), + roles=(template["user"], template["assistant"]), + sep_style=sep_style, + sep=template.get("sep", "\n"), + stop_str=template["stop_str"], + ), + override=True, + ) + chat_template_name = template["name"] else: chat_template_name = chat_template_arg @@ -261,20 +275,32 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe end_point = batch_storage[batch_id].endpoint file_request_list = [] all_requests = [] + request_ids = [] for line in lines: request_data = json.loads(line) file_request_list.append(request_data) body = request_data["body"] + request_ids.append(request_data["custom_id"]) + + # Although streaming is supported for standalone completions, it is not supported in + # batch mode (multiple completions in single request). 
+ if body.get("stream", False): + raise ValueError("Streaming requests are not supported in batch mode") + if end_point == "/v1/chat/completions": all_requests.append(ChatCompletionRequest(**body)) elif end_point == "/v1/completions": all_requests.append(CompletionRequest(**body)) + if end_point == "/v1/chat/completions": adapted_request, request = v1_chat_generate_request( - all_requests, tokenizer_manager + all_requests, tokenizer_manager, request_ids=request_ids ) elif end_point == "/v1/completions": - adapted_request, request = v1_generate_request(all_requests) + adapted_request, request = v1_generate_request( + all_requests, request_ids=request_ids + ) + try: ret = await tokenizer_manager.generate_request(adapted_request).__anext__() if not isinstance(ret, list): @@ -306,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe } all_ret.append(response_json) completed_requests += 1 + # Write results to a new file output_file_id = f"backend_result_file-{uuid.uuid4()}" global storage_dir @@ -335,7 +362,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe } except Exception as e: - print("error in SGLang:", e) + logger.error(f"error in SGLang: {e}") # Update batch status to "failed" retrieve_batch = batch_storage[batch_id] retrieve_batch.status = "failed" @@ -352,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str): return batch_response +async def v1_cancel_batch(tokenizer_manager, batch_id: str): + # Retrieve the batch job from the in-memory storage + batch_response = batch_storage.get(batch_id) + if batch_response is None: + raise HTTPException(status_code=404, detail="Batch not found") + + # Only cancel when the status is "validating" or "in_progress" + if batch_response.status in ["validating", "in_progress"]: + # Start cancelling the batch asynchronously + asyncio.create_task( + cancel_batch( + tokenizer_manager=tokenizer_manager, + batch_id=batch_id, + input_file_id=batch_response.input_file_id, + ) + ) + + # Update batch status to "cancelling" + batch_response.status = "cancelling" + + return batch_response + else: + raise HTTPException( + status_code=500, + detail=f"Current status is {batch_response.status}, no need to cancel", + ) + + +async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): + try: + # Update the batch status to "cancelling" + batch_storage[batch_id].status = "cancelling" + + # Retrieve the input file content + input_file_request = file_id_request.get(input_file_id) + if not input_file_request: + raise ValueError("Input file not found") + + # Parse the JSONL file and process each request + input_file_path = file_id_storage.get(input_file_id) + with open(input_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + file_request_list = [] + request_ids = [] + for line in lines: + request_data = json.loads(line) + file_request_list.append(request_data) + request_ids.append(request_data["custom_id"]) + + # Cancel requests by request_ids + for rid in request_ids: + tokenizer_manager.abort_request(rid=rid) + + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "cancelled" + + except Exception as e: + logger.error(f"error in SGLang: {e}") + # Update batch status to "failed" + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "failed" + retrieve_batch.failed_at = int(time.time()) + retrieve_batch.errors = {"message": str(e)} + + async def v1_retrieve_file(file_id: str): # Retrieve the batch job from the in-memory storage file_response =
file_id_response.get(file_id) @@ -372,20 +465,35 @@ def iter_file(): return StreamingResponse(iter_file(), media_type="application/octet-stream") -def v1_generate_request(all_requests): +def v1_generate_request( + all_requests: List[CompletionRequest], request_ids: List[str] = None +): prompts = [] sampling_params_list = [] return_logprobs = [] + logprob_start_lens = [] top_logprobs_nums = [] - first_prompt_type = type(all_requests[0].prompt) + # NOTE: with the OpenAI API, the prompt's logprobs are never computed + first_prompt_type = type(all_requests[0].prompt) for request in all_requests: - prompt = request.prompt assert ( - type(prompt) == first_prompt_type + type(request.prompt) == first_prompt_type ), "All prompts must be of the same type in file input settings" - prompts.append(prompt) + if len(all_requests) > 1 and request.n > 1: + raise ValueError( + "Parallel sampling is not supported for completions from files" + ) + if request.echo and request.logprobs: + logger.warning( + "Echo is not compatible with logprobs. " + "To compute logprobs of the input prompt, please use the SGLang /request API." + ) + + for request in all_requests: + prompts.append(request.prompt) return_logprobs.append(request.logprobs is not None and request.logprobs > 0) + logprob_start_lens.append(-1) top_logprobs_nums.append( request.logprobs if request.logprobs is not None else 0 ) @@ -401,18 +509,16 @@ def v1_generate_request(all_requests): "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, "regex": request.regex, + "json_schema": request.json_schema, "n": request.n, "ignore_eos": request.ignore_eos, } ) - if len(all_requests) > 1 and request.n > 1: - raise ValueError( - "Parallel sampling is not supported for completions from files" - ) if len(all_requests) == 1: prompt = prompts[0] sampling_params_list = sampling_params_list[0] + logprob_start_lens = logprob_start_lens[0] return_logprobs = return_logprobs[0] top_logprobs_nums = top_logprobs_nums[0] if isinstance(prompt, str) or isinstance(prompt[0], str): @@ -430,8 +536,10 @@ def v1_generate_request(all_requests): sampling_params=sampling_params_list, return_logprob=return_logprobs, top_logprobs_num=top_logprobs_nums, + logprob_start_len=logprob_start_lens, return_text_in_logprobs=True, stream=all_requests[0].stream, + rid=request_ids, ) if len(all_requests) == 1: @@ -569,27 +677,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request): if adapted_request.stream: async def generate_stream_resp(): - stream_buffer = "" - n_prev_token = 0 + stream_buffers = {} + n_prev_tokens = {} + prompt_tokens = {} + completion_tokens = {} try: async for content in tokenizer_manager.generate_request( adapted_request, raw_request ): + index = content["index"] + + stream_buffer = stream_buffers.get(index, "") + n_prev_token = n_prev_tokens.get(index, 0) + text = content["text"] - prompt_tokens = content["meta_info"]["prompt_tokens"] - completion_tokens = content["meta_info"]["completion_tokens"] + prompt_tokens[index] = content["meta_info"]["prompt_tokens"] + completion_tokens[index] = content["meta_info"]["completion_tokens"] if not stream_buffer: # The first chunk if request.echo: if isinstance(request.prompt, str): # for the case of single str prompts prompts = request.prompt - elif isinstance(request.prompt, list) and isinstance( - request.prompt[0], int - ): - prompts = tokenizer_manager.tokenizer.decode( - request.prompt, skip_special_tokens=True - ) + elif isinstance(request.prompt, list): + if
isinstance(request.prompt[0], str): + # for the case of multiple str prompts + prompts = request.prompt[index // request.n] + elif isinstance(request.prompt[0], int): + # for the case of single token ids prompt + prompts = tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + elif isinstance(request.prompt[0], list) and isinstance( + request.prompt[0][0], int + ): + # for the case of multiple token ids prompts + prompts = tokenizer_manager.tokenizer.decode( + request.prompt[index // request.n], + skip_special_tokens=True, + ) # Prepend prompt in response text. text = prompts + text @@ -626,7 +752,7 @@ async def generate_stream_resp(): delta = text[len(stream_buffer) :] stream_buffer = stream_buffer + delta choice_data = CompletionResponseStreamChoice( - index=0, + index=index, text=delta, logprobs=logprobs, finish_reason=format_finish_reason( @@ -639,12 +765,24 @@ async def generate_stream_resp(): choices=[choice_data], model=request.model, ) + + stream_buffers[index] = stream_buffer + n_prev_tokens[index] = n_prev_token + yield f"data: {chunk.model_dump_json()}\n\n" if request.stream_options and request.stream_options.include_usage: + total_prompt_tokens = sum( + tokens + for i, tokens in prompt_tokens.items() + if i % request.n == 0 + ) + total_completion_tokens = sum( + tokens for tokens in completion_tokens.values() + ) usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, ) final_usage_chunk = CompletionStreamResponse( @@ -683,12 +821,20 @@ async def generate_stream_resp(): return response -def v1_chat_generate_request(all_requests, tokenizer_manager): +def v1_chat_generate_request( + all_requests: List[ChatCompletionRequest], + tokenizer_manager, + request_ids: List[str] = None, +): input_ids = [] sampling_params_list = [] image_data_list = [] return_logprobs = [] + logprob_start_lens = [] top_logprobs_nums = [] + + # NOTE: with the OpenAI API, the prompt's logprobs are never computed + for request in all_requests: # Prep the data needed for the underlying GenerateReqInput: # - prompt: The full prompt string.
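The streaming handler above now keeps its per-choice state (stream_buffers, n_prev_tokens, prompt_tokens, completion_tokens) keyed by the `index` field of each chunk, and the final usage counts prompt tokens only once per original prompt while summing completion tokens over all choices. A minimal sketch of that aggregation convention, assuming (as the diff does) that the n parallel samples of one prompt occupy consecutive indices and only indices divisible by n carry a distinct prompt:

# Hypothetical helper, not part of the diff, illustrating the usage aggregation above.
def aggregate_usage(prompt_tokens: dict, completion_tokens: dict, n: int):
    # Prompt tokens: each group of n parallel samples shares one prompt, so count
    # only every n-th index; completion tokens: sum over all streamed choices.
    total_prompt = sum(v for i, v in prompt_tokens.items() if i % n == 0)
    total_completion = sum(completion_tokens.values())
    return total_prompt, total_completion, total_prompt + total_completion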
@@ -721,6 +867,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): image_data = None input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) + logprob_start_lens.append(-1) top_logprobs_nums.append(request.top_logprobs) sampling_params_list.append( { @@ -734,6 +881,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, "regex": request.regex, + "json_schema": request.json_schema, "n": request.n, } ) @@ -747,20 +895,24 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): sampling_params_list = sampling_params_list[0] image_data = image_data_list[0] return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} else: prompt_kwargs = {"input_ids": input_ids} + adapted_request = GenerateReqInput( **prompt_kwargs, image_data=image_data, sampling_params=sampling_params_list, return_logprob=return_logprobs, + logprob_start_len=logprob_start_lens, top_logprobs_num=top_logprobs_nums, stream=all_requests[0].stream, return_text_in_logprobs=True, + rid=request_ids, ) if len(all_requests) == 1: return adapted_request, all_requests[0] @@ -881,16 +1033,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): if adapted_request.stream: async def generate_stream_resp(): - is_first = True - - stream_buffer = "" - n_prev_token = 0 + is_firsts = {} + stream_buffers = {} + n_prev_tokens = {} + prompt_tokens = {} + completion_tokens = {} try: async for content in tokenizer_manager.generate_request( adapted_request, raw_request ): - prompt_tokens = content["meta_info"]["prompt_tokens"] - completion_tokens = content["meta_info"]["completion_tokens"] + index = content["index"] + + is_first = is_firsts.get(index, True) + stream_buffer = stream_buffers.get(index, "") + n_prev_token = n_prev_tokens.get(index, 0) + + prompt_tokens[index] = content["meta_info"]["prompt_tokens"] + completion_tokens[index] = content["meta_info"]["completion_tokens"] if request.logprobs: logprobs = to_openai_style_logprobs( output_token_logprobs=content["meta_info"][ @@ -940,7 +1099,7 @@ async def generate_stream_resp(): # First chunk with role is_first = False choice_data = ChatCompletionResponseStreamChoice( - index=0, + index=index, delta=DeltaMessage(role="assistant"), finish_reason=format_finish_reason( content["meta_info"]["finish_reason"] @@ -958,7 +1117,7 @@ async def generate_stream_resp(): delta = text[len(stream_buffer) :] stream_buffer = stream_buffer + delta choice_data = ChatCompletionResponseStreamChoice( - index=0, + index=index, delta=DeltaMessage(content=delta), finish_reason=format_finish_reason( content["meta_info"]["finish_reason"] @@ -970,12 +1129,25 @@ async def generate_stream_resp(): choices=[choice_data], model=request.model, ) + + is_firsts[index] = is_first + stream_buffers[index] = stream_buffer + n_prev_tokens[index] = n_prev_token + yield f"data: {chunk.model_dump_json()}\n\n" if request.stream_options and request.stream_options.include_usage: + total_prompt_tokens = sum( + tokens + for i, tokens in prompt_tokens.items() + if i % request.n == 0 + ) + total_completion_tokens = sum( + tokens for tokens in completion_tokens.values() + ) usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + 
prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, ) final_usage_chunk = ChatCompletionStreamResponse( diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 758e48edef..ce51e1c029 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -161,6 +161,7 @@ class CompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None + json_schema: Optional[str] = None ignore_eos: Optional[bool] = False min_tokens: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 @@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None + json_schema: Optional[str] = None min_tokens: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py new file mode 100644 index 0000000000..7843f4bd32 --- /dev/null +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, List + +import torch + +import sglang.srt.sampling.penaltylib as penaltylib + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ScheduleBatch + + +@dataclasses.dataclass +class SamplingBatchInfo: + # Basic Info + vocab_size: int + + # Batched sampling params + temperatures: torch.Tensor = None + top_ps: torch.Tensor = None + top_ks: torch.Tensor = None + min_ps: torch.Tensor = None + + # Dispatch in CUDA graph + need_min_p_sampling: bool = False + + # Bias Tensors + logit_bias: torch.Tensor = None + vocab_mask: torch.Tensor = None + + # Penalizer + penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None + linear_penalties: torch.Tensor = None + scaling_penalties: torch.Tensor = None + + def has_bias(self): + return ( + self.logit_bias is not None + or self.vocab_mask is not None + or self.linear_penalties is not None + or self.scaling_penalties is not None + ) + + @classmethod + def dummy_one(cls, max_bs: int, vocab_size: int): + ret = cls(vocab_size=vocab_size) + ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda") + ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda") + ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda") + ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda") + return ret + + def __getitem__(self, key): + if isinstance(key, slice): + # NOTE: We do not use CUDA graph when there are bias tensors + assert not self.has_bias() + return SamplingBatchInfo( + vocab_size=self.vocab_size, + temperatures=self.temperatures[key], + top_ps=self.top_ps[key], + top_ks=self.top_ks[key], + min_ps=self.min_ps[key], + need_min_p_sampling=self.need_min_p_sampling, + ) + else: + raise NotImplementedError + + def inplace_assign(self, bs: int, other: SamplingBatchInfo): + # NOTE: We do not use CUDA graph when there are bias tensors + assert not self.has_bias() + + self.vocab_size = other.vocab_size + self.need_min_p_sampling = other.need_min_p_sampling + + self.temperatures[:bs] = other.temperatures + self.top_ps[:bs] = other.top_ps + self.top_ks[:bs] = other.top_ks + self.min_ps[:bs] = other.min_ps + +
@classmethod + def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): + device = "cuda" + reqs = batch.reqs + ret = cls(vocab_size=vocab_size) + + ret.temperatures = torch.tensor( + [r.sampling_params.temperature for r in reqs], + dtype=torch.float, + device=device, + ).view(-1, 1) + ret.top_ps = torch.tensor( + [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device + ) + ret.top_ks = torch.tensor( + [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device + ) + ret.min_ps = torch.tensor( + [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device + ) + ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs) + + # Each penalizer will do nothing if it evaluates itself as not required by looking at + # the sampling_params of the requests (see {_is_required()} of each penalizer), so this + # should not add much computation overhead beyond simple checks. + # + # While we could choose not to even create the class instances when they are not required, + # that would add additional complexity to the {ScheduleBatch} class, especially since we + # would need to handle the {filter_batch()} and {merge()} cases as well. + ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( + vocab_size=vocab_size, + batch=batch, + device=device, + Penalizers={ + penaltylib.BatchedFrequencyPenalizer, + penaltylib.BatchedMinNewTokensPenalizer, + penaltylib.BatchedPresencePenalizer, + penaltylib.BatchedRepetitionPenalizer, + }, + ) + + # Handle logit bias but only allocate when needed + ret.logit_bias = None + + ret.update_regex_vocab_mask(batch) + + return ret + + def prepare_penalties(self): + self.scaling_penalties = None + self.linear_penalties = None + + for penalizer in self.penalizer_orchestrator.penalizers.values(): + if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): + if penalizer.is_prepared(): + self.scaling_penalties = penalizer.cumulated_repetition_penalties + else: + if penalizer.is_prepared(): + if self.linear_penalties is None: + bs = self.penalizer_orchestrator.batch.batch_size() + self.linear_penalties = torch.zeros( + (bs, self.vocab_size), + dtype=torch.float32, + device="cuda", + ) + self.linear_penalties = penalizer.apply(self.linear_penalties) + + def update_regex_vocab_mask(self, batch: ScheduleBatch): + bs, reqs = batch.batch_size(), batch.reqs + device = "cuda" + has_regex = any(req.regex_fsm is not None for req in reqs) + + # Reset the vocab mask + self.vocab_mask = None + + if has_regex: + for i, req in enumerate(reqs): + if req.regex_fsm is not None: + if self.vocab_mask is None: + self.vocab_mask = torch.zeros( + bs, self.vocab_size, dtype=torch.bool, device=device + ) + self.vocab_mask[i][ + req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens + ] = 1 + + def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor): + self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + + for item in [ + "temperatures", + "top_ps", + "top_ks", + "min_ps", + "logit_bias", + ]: + self_val = getattr(self, item, None) + if self_val is not None: # logit_bias can be None + setattr(self, item, self_val[new_indices]) + + def merge(self, other: "SamplingBatchInfo"): + self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + + for item in [ + "temperatures", + "top_ps", + "top_ks", + "min_ps", + ]: + self_val = getattr(self, item, None) + other_val = getattr(other, item, None) + setattr(self, item, torch.concat([self_val, other_val])) + + # logit_bias can be None + if
self.logit_bias is not None or other.logit_bias is not None: + vocab_size = ( + self.logit_bias.shape[1] + if self.logit_bias is not None + else other.logit_bias.shape[1] + ) + if self.logit_bias is None: + self.logit_bias = torch.zeros( + (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda" + ) + if other.logit_bias is None: + other.logit_bias = torch.zeros( + (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda" + ) + self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py similarity index 84% rename from python/sglang/srt/sampling_params.py rename to python/sglang/srt/sampling/sampling_params.py index 29067dc851..8111757d85 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -30,19 +30,21 @@ def __init__( temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, repetition_penalty: float = 1.0, ignore_eos: bool = False, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - dtype: Optional[str] = None, regex: Optional[str] = None, n: int = 1, + json_schema: Optional[str] = None, ) -> None: self.temperature = temperature self.top_p = top_p self.top_k = top_k + self.min_p = min_p self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty self.repetition_penalty = repetition_penalty @@ -53,9 +55,9 @@ def __init__( self.ignore_eos = ignore_eos self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens - self.dtype = dtype self.regex = regex self.n = n + self.json_schema = json_schema # Process some special cases if self.temperature < _SAMPLING_EPS: @@ -63,8 +65,6 @@ def __init__( self.top_k = 1 if self.top_k == -1: self.top_k = 1 << 30 # whole vocabulary - if self.dtype == "int": - self.stop_strs = [" ", "\n"] def verify(self): if self.temperature < 0.0: @@ -73,6 +73,8 @@ def verify(self): ) if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if not 0.0 <= self.min_p <= 1.0: + raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") if self.top_k < -1 or self.top_k == 0: raise ValueError( f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." @@ -106,6 +108,8 @@ def verify(self): f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " f"{self.min_new_tokens}." 
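For readers unfamiliar with the `min_p` option introduced in this diff: min-p sampling typically keeps only tokens whose probability is at least `min_p` times the probability of the most likely token. The following is a minimal reference sketch of that semantics, not the kernel actually used by the server:

```python
import torch


def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Keep tokens whose probability is >= min_p * max_prob, then renormalize.

    Reference sketch of the usual min-p semantics; not the server's implementation.
    """
    if min_p <= 0.0:
        return probs
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    masked = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    return masked / masked.sum(dim=-1, keepdim=True)


# With min_p=0.1, tokens below 10% of the top token's probability are dropped.
probs = torch.tensor([0.70, 0.20, 0.06, 0.04])
print(apply_min_p(probs, min_p=0.1))  # ~[0.778, 0.222, 0.000, 0.000]
```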
) + if self.regex is not None and self.json_schema is not None: + raise ValueError("regex and json_schema cannot be both set.") def normalize(self, tokenizer): # Process stop strings @@ -127,3 +131,17 @@ def normalize(self, tokenizer): else: stop_str_max_len = max(stop_str_max_len, len(stop_str)) self.stop_str_max_len = stop_str_max_len + + def to_srt_kwargs(self): + return { + "max_new_tokens": self.max_new_tokens, + "stop": self.stop_strs, + "stop_token_ids": list(self.stop_token_ids), + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "ignore_eos": self.ignore_eos, + "regex": self.regex, + } diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 8b67663357..5ba2a45e70 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -24,7 +24,6 @@ import logging import multiprocessing as mp import os -import sys import threading import time from http import HTTPStatus @@ -34,7 +33,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import aiohttp -import psutil import requests import uvicorn import uvloop @@ -52,11 +50,16 @@ start_controller_process as start_controller_process_single, ) from sglang.srt.managers.detokenizer_manager import start_detokenizer_process -from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + UpdateWeightReqInput, +) from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, v1_batches, + v1_cancel_batch, v1_chat_completions, v1_completions, v1_delete_file, @@ -72,6 +75,7 @@ add_api_key_middleware, allocate_init_ports, assert_pkg_version, + configure_logger, enable_show_time_cost, kill_child_process, maybe_set_triton_cache_manager, @@ -92,10 +96,25 @@ @app.get("/health") async def health() -> Response: - """Health check.""" + """Check the health of the http server.""" return Response(status_code=200) +@app.get("/health_generate") +async def health_generate(request: Request) -> Response: + """Check the health of the inference server by generating one token.""" + gri = GenerateReqInput( + text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7} + ) + try: + async for _ in tokenizer_manager.generate_request(gri, request): + break + return Response(status_code=200) + except Exception as e: + logger.exception(e) + return Response(status_code=503) + + @app.get("/get_model_info") async def get_model_info(): result = { @@ -120,6 +139,23 @@ async def flush_cache(): ) +@app.post("/update_weights") +async def update_weights(obj: UpdateWeightReqInput, request: Request): + + success, message = await tokenizer_manager.update_weights(obj, request) + content = {"message": message, "success": str(success)} + if success: + return JSONResponse( + content, + status_code=HTTPStatus.OK, + ) + else: + return JSONResponse( + content, + status_code=HTTPStatus.BAD_REQUEST, + ) + + async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" if obj.stream: @@ -211,6 +247,12 @@ async def openai_v1_batches(raw_request: Request): return await v1_batches(tokenizer_manager, raw_request) +@app.post("/v1/batches/{batch_id}/cancel") +async def cancel_batches(batch_id: str): + # https://platform.openai.com/docs/api-reference/batch/cancel + return await 
v1_cancel_batch(tokenizer_manager, batch_id) + + @app.get("/v1/batches/{batch_id}") async def retrieve_batch(batch_id: str): return await v1_retrieve_batch(batch_id) @@ -236,15 +278,12 @@ def launch_server( """Launch an HTTP server.""" global tokenizer_manager - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) + configure_logger(server_args) server_args.check_server_args() _set_envs_and_config(server_args) - # Allocate ports + # Allocate ports for inter-process communications server_args.port, server_args.additional_ports = allocate_init_ports( server_args.port, server_args.additional_ports, @@ -264,42 +303,48 @@ def launch_server( server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path) # Launch processes for multi-node tensor parallelism - if server_args.nnodes > 1: - if server_args.node_rank != 0: - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [ - i for _ in range(server_args.nnodes) for i in range(tp_size_local) - ] - tp_rank_range = list( - range( - server_args.node_rank * tp_size_local, - (server_args.node_rank + 1) * tp_size_local, - ) - ) - procs = launch_tp_servers( - gpu_ids, - tp_rank_range, - server_args, - ports[3], - model_overide_args, + if server_args.nnodes > 1 and server_args.node_rank != 0: + tp_size_local = server_args.tp_size // server_args.nnodes + gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] + tp_rank_range = list( + range( + server_args.node_rank * tp_size_local, + (server_args.node_rank + 1) * tp_size_local, ) - while True: - pass + ) + procs = launch_tp_servers( + gpu_ids, + tp_rank_range, + server_args, + ports[3], + model_overide_args, + ) + + try: + for p in procs: + p.join() + finally: + kill_child_process(os.getpid(), including_parent=False) + return # Launch processes tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: - start_process = start_controller_process_single + start_controller_process = start_controller_process_single else: - start_process = start_controller_process_multi + start_controller_process = start_controller_process_multi + proc_controller = mp.Process( - target=start_process, + target=start_controller_process, args=(server_args, port_args, pipe_controller_writer, model_overide_args), ) proc_controller.start() + proc_detoken = mp.Process( target=start_detokenizer_process, args=( @@ -317,15 +362,11 @@ def launch_server( if controller_init_state != "init ok" or detoken_init_state != "init ok": proc_controller.kill() proc_detoken.kill() - print( - f"Initialization failed. controller_init_state: {controller_init_state}", - flush=True, + raise RuntimeError( + "Initialization failed. " + f"controller_init_state: {controller_init_state}, " + f"detoken_init_state: {detoken_init_state}" ) - print( - f"Initialization failed. 
detoken_init_state: {detoken_init_state}", - flush=True, - ) - sys.exit(1) assert proc_controller.is_alive() and proc_detoken.is_alive() # Add api key authorization @@ -334,12 +375,12 @@ def launch_server( # Send a warmup request t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer) + target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid()) ) t.start() - # Listen for requests try: + # Listen for requests uvicorn.run( app, host=server_args.host, @@ -358,6 +399,7 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["NCCL_CUMEM_ENABLE"] = "0" os.environ["NCCL_NVLS_ENABLE"] = "0" os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" # Set ulimit set_ulimit() @@ -375,23 +417,18 @@ def _set_envs_and_config(server_args: ServerArgs): # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. maybe_set_triton_cache_manager() - # Set global chat template - if server_args.chat_template: - # TODO: replace this with huggingface transformers template - load_chat_template_for_openai_api(server_args.chat_template) - # Check flashinfer version if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", - "0.1.4", + "0.1.6", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", ) -def _wait_and_warmup(server_args, pipe_finish_writer): +def _wait_and_warmup(server_args, pipe_finish_writer, pid): headers = {} url = server_args.url() if server_args.api_key: @@ -414,8 +451,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer): if not success: if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) - print(f"Initialization failed. warmup error: {last_traceback}", flush=True) - sys.exit(1) + logger.error(f"Initialization failed. warmup error: {last_traceback}") + kill_child_process(pid, including_parent=False) + return # Send a warmup request request_name = "/generate" if model_info["is_generation"] else "/encode" @@ -440,12 +478,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer): timeout=600, ) assert res.status_code == 200, f"{res}" - except Exception as e: + except Exception: last_traceback = get_exception_traceback() if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) - print(f"Initialization failed. warmup error: {last_traceback}", flush=True) - sys.exit(1) + logger.error(f"Initialization failed. 
warmup error: {last_traceback}") + kill_child_process(pid, including_parent=False) + return logger.info("The server is fired up and ready to roll!") if pipe_finish_writer is not None: @@ -483,6 +522,7 @@ def __init__( self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) + proc = mp.Process( target=launch_server, args=(self.server_args, model_overide_args, pipe_writer), @@ -524,11 +564,18 @@ async def async_generate( prompt: str, sampling_params: Optional[Dict] = None, ): - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "stream": True, - } + if self.server_args.skip_tokenizer_init: + json_data = { + "input_ids": prompt, + "sampling_params": sampling_params, + "stream": True, + } + else: + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "stream": True, + } pos = 0 timeout = aiohttp.ClientTimeout(total=3 * 3600) @@ -540,24 +587,29 @@ async def async_generate( if chunk == "data: [DONE]\n\n": break data = json.loads(chunk[5:].strip("\n")) - cur = data["text"][pos:] - if cur: - yield cur - pos += len(cur) + if hasattr(data, "text"): + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + else: + yield data add_request = async_generate def generate( self, - prompt: str, + prompt: Union[str, List[str]], sampling_params: Optional[Dict] = None, return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, ): json_data = { "text": prompt, "sampling_params": sampling_params, "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, "top_logprobs_num": top_logprobs_num, } response = requests.post( @@ -568,7 +620,7 @@ def generate( def encode( self, - prompt: str, + prompt: Union[str, List[str]], ): json_data = { "text": prompt, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6fd11d1345..70c7204d2f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -17,9 +17,12 @@ import argparse import dataclasses +import logging import random from typing import List, Optional, Union +logger = logging.getLogger(__name__) + @dataclasses.dataclass class ServerArgs: @@ -30,11 +33,13 @@ class ServerArgs: skip_tokenizer_init: bool = False load_format: str = "auto" dtype: str = "auto" + kv_cache_dtype: str = "auto" trust_remote_code: bool = True context_length: Optional[int] = None quantization: Optional[str] = None served_model_name: Optional[str] = None chat_template: Optional[str] = None + is_embedding: bool = False # Port host: str = "127.0.0.1" @@ -46,7 +51,7 @@ class ServerArgs: max_running_requests: Optional[int] = None max_num_reqs: Optional[int] = None max_total_tokens: Optional[int] = None - chunked_prefill_size: int = -1 + chunked_prefill_size: int = 8192 max_prefill_tokens: int = 16384 schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 @@ -76,12 +81,14 @@ class ServerArgs: disable_radix_cache: bool = False disable_regex_jump_forward: bool = False disable_cuda_graph: bool = False + disable_cuda_graph_padding: bool = False disable_disk_cache: bool = False + disable_custom_all_reduce: bool = False + enable_mixed_chunk: bool = False enable_torch_compile: bool = False enable_p2p_check: bool = False enable_mla: bool = False - attention_reduce_in_fp32: bool = False - efficient_weight_load: bool = False + triton_attention_reduce_in_fp32: bool = False # Distributed args nccl_init_addr: Optional[str] = None @@ -190,11 +197,23 
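As a usage note for the extended `Runtime.generate` signature shown above, here is a small sketch; the model path is only an example, and any model served by sglang should work the same way:

```python
from sglang.srt.server import Runtime

# Example model path; substitute whatever model you actually serve.
runtime = Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

out = runtime.generate(
    "The capital of the United Kingdom is",
    sampling_params={"max_new_tokens": 8, "temperature": 0.0},
    return_logprob=True,
    logprob_start_len=0,  # new parameter: request prefill logprobs from the first token
    top_logprobs_num=5,
)
print(out)

runtime.shutdown()
```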
@@ def add_cli_args(parser: argparse.ArgumentParser): '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.', ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=ServerArgs.kv_cache_dtype, + choices=["auto", "fp8_e5m2"], + help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', + ) parser.add_argument( "--trust-remote-code", action="store_true", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) + parser.add_argument( + "--is-embedding", + action="store_true", + help="Whether to use a CausalLM as an embedding model.", + ) parser.add_argument( "--context-length", type=int, @@ -389,11 +408,27 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable cuda graph.", ) + parser.add_argument( + "--disable-cuda-graph-padding", + action="store_true", + help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.", + ) parser.add_argument( "--disable-disk-cache", action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) + parser.add_argument( + "--disable-custom-all-reduce", + action="store_true", + default=False, + help="Disable the custom all-reduce kernel and fall back to NCCL.", + ) + parser.add_argument( + "--enable-mixed-chunk", + action="store_true", + help="Enabling mixing prefill and decode in a batch when using chunked prefill.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", @@ -407,13 +442,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-mla", action="store_true", - help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2", + help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.", ) parser.add_argument( - "--attention-reduce-in-fp32", + "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." 
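To make the new options above concrete, here is a hedged sketch of configuring them programmatically, assuming the `ServerArgs` field names shown in this diff; the model path is illustrative:

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # illustrative
    kv_cache_dtype="fp8_e5m2",       # new: store the KV cache in FP8 (CUDA 11.8+)
    enable_mixed_chunk=True,         # new: mix prefill and decode in one batch
    disable_custom_all_reduce=True,  # new: fall back to NCCL for all-reduce
)

# chunked_prefill_size now defaults to 8192 instead of -1 (disabled).
assert args.chunked_prefill_size == 8192
args.check_server_args()
```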
- "This only affects Triton attention kernels", + "This only affects Triton attention kernels.", ) parser.add_argument( "--efficient-weight-load", @@ -431,15 +466,6 @@ def from_cli_args(cls, args: argparse.Namespace): def url(self): return f"http://{self.host}:{self.port}" - def print_mode_args(self): - return ( - f"disable_flashinfer={self.disable_flashinfer}, " - f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, " - f"disable_radix_cache={self.disable_radix_cache}, " - f"disable_regex_jump_forward={self.disable_regex_jump_forward}, " - f"disable_disk_cache={self.disable_disk_cache}, " - ) - def check_server_args(self): assert ( self.tp_size % self.nnodes == 0 @@ -447,6 +473,14 @@ def check_server_args(self): assert not ( self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" + if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: + logger.info( + "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" + ) + self.trust_remote_code = False + if "gemma-2" in self.model_path.lower(): + logger.info("When using sliding window in gemma-2, turn on flashinfer.") + self.disable_flashinfer = False @dataclasses.dataclass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2d20881c8f..66a5679d75 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -26,7 +26,7 @@ import time from importlib.metadata import PackageNotFoundError, version from io import BytesIO -from typing import List, Optional +from typing import List, Optional, Union import numpy as np import psutil @@ -35,7 +35,6 @@ import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version -from starlette.middleware.base import BaseHTTPMiddleware from torch.nn.parameter import Parameter from triton.runtime.cache import ( FileCacheManager, @@ -194,44 +193,30 @@ def allocate_init_ports( return ret_ports[0], ret_ports[1:num_ports_needed] -def get_int_token_logit_bias(tokenizer, vocab_size): - """Get the logit bias for integer-only tokens.""" - # a bug when model's vocab size > tokenizer.vocab_size - if tokenizer == None: - return [-1e5] * vocab_size - vocab_size = tokenizer.vocab_size - logit_bias = np.zeros(vocab_size, dtype=np.float32) - for t_id in range(vocab_size): - ss = tokenizer.decode([t_id]).strip() - if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id): - logit_bias[t_id] = -1e5 - - return logit_bias - - -def is_multimodal_model(model): - from sglang.srt.model_config import ModelConfig - - if isinstance(model, str): - model = model.lower() - return "llava" in model or "yi-vl" in model or "llava-next" in model - - if isinstance(model, ModelConfig): - model_path = model.path.lower() - return ( - "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path - ) +def is_multimodal_model(model_architectures): + if ( + "LlavaLlamaForCausalLM" in model_architectures + or "LlavaQwenForCausalLM" in model_architectures + or "LlavaMistralForCausalLM" in model_architectures + or "LlavaVidForCausalLM" in model_architectures + ): + return True + else: + return False - raise ValueError("unrecognized type") +def is_generation_model(model_architectures, is_embedding: bool = False): + # We have two ways to determine whether a model is a generative model. + # 1. Check the model architectue + # 2. 
Check the `is_embedding` server argument -def is_generation_model(model_architectures): if ( "LlamaEmbeddingModel" in model_architectures or "MistralModel" in model_architectures ): return False - return True + else: + return not is_embedding def decode_video_base64(video_base64): @@ -313,12 +298,14 @@ def decode_video_base64(video_base64): ) # Return an empty array and size tuple if no frames were found -def load_image(image_file): +def load_image(image_file: Union[str, bytes]): from PIL import Image image = image_size = None - if image_file.startswith("http://") or image_file.startswith("https://"): + if isinstance(image_file, bytes): + image = Image.open(BytesIO(image_file)) + elif image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) response = requests.get(image_file, timeout=timeout) image = Image.open(BytesIO(response.content)) @@ -330,8 +317,10 @@ def load_image(image_file): elif image_file.startswith("video:"): image_file = image_file.replace("video:", "") image, image_size = decode_video_base64(image_file) - else: + elif isinstance(image_file, str): image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + raise ValueError(f"Invalid image: {image_file}") return image, image_size @@ -348,7 +337,7 @@ def suppress_other_loggers(): logging.WARN ) logging.getLogger("vllm.selector").setLevel(logging.WARN) - logging.getLogger("vllm.utils").setLevel(logging.WARN) + logging.getLogger("vllm.utils").setLevel(logging.ERROR) def assert_pkg_version(pkg: str, min_version: str, message: str): @@ -370,14 +359,11 @@ def kill_parent_process(): """Kill the parent process and all children of the parent process.""" current_process = psutil.Process() parent_process = current_process.parent() - children = parent_process.children(recursive=True) - for child in children: - if child.pid != current_process.pid: - os.kill(child.pid, 9) - os.kill(parent_process.pid, 9) + kill_child_process(parent_process.pid, skip_pid=current_process.pid) -def kill_child_process(pid, including_parent=True): +def kill_child_process(pid, including_parent=True, skip_pid=None): + """Kill the process and all of its child processes.""" try: parent = psutil.Process(pid) except psutil.NoSuchProcess: @@ -385,6 +371,8 @@ def kill_child_process(pid, including_parent=True): children = parent.children(recursive=True) for child in children: + if child.pid == skip_pid: + continue try: child.kill() except psutil.NoSuchProcess: @@ -419,7 +407,6 @@ def monkey_patch_vllm_dummy_weight_loader(): DummyModelLoader, LoRAConfig, ModelConfig, - MultiModalConfig, ParallelConfig, SchedulerConfig, _initialize_model, @@ -434,7 +421,6 @@ def load_model( model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig, @@ -445,7 +431,6 @@ def load_model( model_config, self.load_config, lora_config, - multimodal_config, cache_config, ) @@ -453,10 +438,6 @@ def load_model( quant_method = getattr(module, "quant_method", None) if quant_method is not None: quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights.
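A brief usage sketch for the updated `load_image` helper above, which now also accepts raw bytes in addition to URLs, base64 strings, and the "video:" form; the URL and file name below are placeholders:

```python
from sglang.srt.utils import load_image

# From a URL or base64 string (existing behavior).
image, image_size = load_image("https://example.com/cat.png")  # placeholder URL

# From raw bytes (new in this diff), e.g. the body of an uploaded file.
with open("cat.png", "rb") as f:  # placeholder local file
    image, image_size = load_image(f.read())

# image_size is only populated for the "video:<base64>" input form; otherwise it is None.
print(image.size, image_size)
```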
@@ -644,7 +625,7 @@ def set_ulimit(target_soft_limit=65535): logger.warn(f"Fail to set RLIMIT_NOFILE: {e}") -def is_llama3_405b_fp8(model_config): +def is_llama3_405b_fp8_head_16(model_config): """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads.""" if ( model_config.hf_config.architectures[0] == "LlamaForCausalLM" @@ -693,7 +674,7 @@ def weight_loader_srt( setattr(QKVParallelLinear, "weight_loader", weight_loader_srt) -def add_api_key_middleware(app, api_key): +def add_api_key_middleware(app, api_key: str): @app.middleware("http") async def authentication(request, call_next): if request.method == "OPTIONS": @@ -705,7 +686,7 @@ async def authentication(request, call_next): return await call_next(request) -def prepare_model(model_path): +def prepare_model(model_path: str): if "SGLANG_USE_MODELSCOPE" in os.environ: if not os.path.exists(model_path): from modelscope import snapshot_download @@ -714,7 +695,7 @@ def prepare_model(model_path): return model_path -def prepare_tokenizer(tokenizer_path): +def prepare_tokenizer(tokenizer_path: str): if "SGLANG_USE_MODELSCOPE" in os.environ: if not os.path.exists(tokenizer_path): from modelscope import snapshot_download @@ -723,3 +704,13 @@ def prepare_tokenizer(tokenizer_path): tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"] ) return tokenizer_path + + +def configure_logger(server_args, prefix: str = ""): + format = f"[%(asctime)s{prefix}] %(message)s" + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format=format, + datefmt="%H:%M:%S", + force=True, + ) diff --git a/python/sglang/test/long_prompt.txt b/python/sglang/test/long_prompt.txt new file mode 100644 index 0000000000..301d7e107d --- /dev/null +++ b/python/sglang/test/long_prompt.txt @@ -0,0 +1 @@ +You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\nIntroduction\n\nThroughout U.S. history, Congress has created advisory commissions to assist in the development of public policy. Among other contexts, commissions have been used following crisis situations, including the September 11, 2001, terrorist attacks and the 2008 financial crisis. In such situations, advisory commissions may potentially provide Congress with a high-visibility forum to assemble expertise that might not exist within the legislative environment; allow for the in-depth examination of complex, cross-cutting policy issues; and lend bipartisan credibility to a set of findings and recommendations.\nAs Congress considers its range of responses to the coronavirus pandemic, the creation of one or more congressional advisory commissions is an option that could provide a platform for evaluating various pandemic-related policy issues over time. Past congressional advisory commissions have retrospectively evaluated policy responses, brought together diverse groups of experts, and supplemented existing congressional oversight mechanisms. Policymakers may determine that creating an advisory commission is unnecessary and instead prefer to utilize existing congressional oversight structures, such as standing or select committees, or already established oversight entities.\nThis report provides a comparative analysis of five proposed congressional advisory commissions that would investigate various aspects of the COVID-19 pandemic. The five proposed commissions are found in H.R. 6429 (the National Commission on COVID-19 Act, sponsored by Representative Stephanie Murphy), H.R. 
6431 (the Made in America Emergency Preparedness Act, sponsored by Representative Brian Fitzpatrick), H.R. 6440 (the Pandemic Rapid Response Act, sponsored by Representative Rodney Davis), H.R. 6455 (the COVID-19 Commission Act, sponsored by Representative Bennie Thompson), and H.R. 6548 (the National Commission on the COVID-19 Pandemic in the United States Act, sponsored by Representative Adam Schiff). The overall structures of each of the proposed commissions are similar in many respects, both to each other and to previous independent advisory entities established by Congress. Specifically, the proposed commissions would (1) exist temporarily; (2) serve in an advisory capacity; and (3) report a work product detailing the commission\'s findings, conclusions, and recommendations. That said, each particular proposed commission has distinctive elements, particularly concerning its membership structure, appointment structure, and time line for reporting its work product to Congress.\nThis report compares the (1) membership structure, (2) appointment structure, (3) rules of procedure and operation, (4) duties and reporting requirements, (5) powers of the commission, (6) staffing issues, and (7) funding for each of the proposed COVID-19 commissions. Table 1 (at the end of this report) provides a side-by-side comparison of major provisions of the five proposals.\n\n Membership Structure\n\nSeveral matters related to a commission\'s membership structure might be considered. They include the size of a commission, member qualifications, compensation of commission members, and requirements for partisan balance. \n\n Size of Commission\n\nIn general, there is significant variation in the size of congressional advisory commissions. Among 155 identified congressional commissions created between the 101 st Congress and the 115 th Congress, the median size was 12 members, with the smallest commission having 5 members and the largest 33 members.\nThe membership structure of each of the five proposed commissions is similar to previous independent advisory entities created by Congress. H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 would each create a 10-member entity. H.R. 6455 would create a 25-member entity.\n\n Qualifications\n\nPast legislation creating congressional commissions has often required or suggested that commission members possess certain substantive qualifications. Such provisions arguably make it more likely that the commission is populated with genuine experts in the policy area, which may improve the commission\'s final work product.\nH.R. 6455 would provide that commissioners \"shall be a United States person with significant expertise\" in a variety of fields related to public health and public administration. H.R. 6440 , H.R. 6429 , H.R. 6431 , and H.R. 6548 would provide \"the sense of Congress\" that commission members should be \"prominent U.S. citizens\" who are nationally recognized experts in a variety of fields relevant to the pandemic and response efforts. In addition, H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 all prohibit the appointment of federal, state, and local government employees and officers. H.R. 6455 would prohibit federal employees from being commission members.\n\n Compensation of Commission Members\n\nSome congressional commissions have compensated their members. 
For example, the National Commission on Terrorist Attacks Upon the United States (9/11 Commission) and the Financial Crisis Inquiry Commission provided that commission members could be compensated at a daily rate of basic pay. Nearly all have reimbursed members for travel expenses. Those that have provided for commissioner compensation most frequently provided compensation at the daily equivalent of level IV of the Executive Schedule.\nEach of the five proposals would provide that commission members be compensated at a rate \"not to exceed the daily equivalent of the annual rate of basic pay\" for level IV of the Executive Schedule, \"for each day during which that member is engaged in the actual performance of duties of the Commission.\" Members of three proposed commissions would receive travel expenses, including a per diem.\n\n Partisan Limitations\n\nEach proposal provides a limit on the number of members appointed from the same political party. H.R. 6455 would provide that not more than 13 of its 25 members may be from the same party. H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 would provide that not more than 5 (of 10) members are from the same party. Most previous advisory entities created by Congress do not impose formal partisan restrictions on the membership structure. It may also be difficult to assess the political affiliation of potential members, who may have no formal affiliation (voter registration, for example) with a political party. Instead, most past advisory commissions usually achieve partisan balance through the appointment structure; for instance, by providing equal (or near-equal) numbers of appointments to congressional leaders of each party.\n\n Appointment Structure\n\nPast congressional commissions have used a wide variety of appointment structures. Considerations regarding appointment structures include partisan balance, filling vacancies, and the time line for making commission appointments.\nThe statutory scheme may directly designate members of the commission, such as a specific cabinet official or a congressional leader. In other cases, selected congressional leaders, often with balance between the parties, appoint commission members. A third common statutory scheme is to have selected leaders, such as committee chairs and ranking members, recommend candidates for appointment to a commission. These selected leaders may act either in parallel or jointly, and the recommendation may be made either to other congressional leaders, such as the Speaker of the House and President pro tempore of the Senate, or to the President.\nEach of the five commission proposals would delegate most or all appointment authority to congressional leaders (including chamber, party, and committee leaders; see Table 1 for details). Additionally, H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 provide for one appointment to be made by the President. H.R. 6429 , H.R. 6431 , and H.R. 6548 would have the President appoint the commission\'s chair. H.R. 6455 has its membership appointed by the chairs and ranking members of designated House and Senate committees, and the Joint Economic Committee. H.R. 
6455 does not provide any executive branch appointments.\nAttention to the proper balance between the number of members appointed by congressional leaders and by other individuals (such as the President), or to the number of Members of Congress required to be among the appointees, or to the qualifications of appointees, can be significant factors in enabling a commission to fulfill its congressional mandate.\nIn general, a commission\'s appointment scheme can impact both the commission\'s ability to fulfill its statutory duties and its final work product. For instance, if the scheme provides only for the appointment of Members of Congress to the commission, it arguably might not have the technical expertise or diversity of knowledge to complete its duties within the time given by statute. Similarly, if the appointment scheme includes qualifying provisos so specific that only a small set of private citizens could serve on the panel, the commission\'s final work product may arguably only represent a narrow range of viewpoints. None of the proposed COVID-19 commissions specify whether Members of Congress may serve on the commission.\n\n Partisan Balance in Appointment Authority\n\nMost previous congressional advisory commissions have been structured to be bipartisan, with an even (or near-even) split of appointments between leaders of the two major parties. By achieving a nonpartisan or bipartisan character, congressional commissions may make their findings and recommendations more politically acceptable to diverse viewpoints. The bipartisan or nonpartisan arrangement can give recommendations strong credibility, both in Congress and among the public, even when dealing with divisive public policy issues. Similarly, commission recommendations that are perceived as partisan may have difficulty gaining support in Congress.\nIn some cases, however, bipartisanship also can arguably impede a commission\'s ability to complete its mandate. In situations where a commission is tasked with studying divisive or partisan issues, the appointment of an equal number of majority and minority commissioners may serve to promote partisanship within the commission rather than suppress it, raising the possibility of deadlock where neither side can muster a majority to act.\nEach of the five proposals employs a structure where leaders in both the majority and minority parties in Congress would make appointments. H.R. 6429 , H.R. 6431 , and H.R. 6548 would provide for five majority and five minority appointments, including one for the President. H.R. 6440 would include two each by the Senate majority leader, the Senate minority leader, and the Speaker of the House, with one appointment by the House minority leader and one by the President, and the chair appointed by the Speaker and vice chair appointed by the Senate majority leader. H.R. 6455 would have 12 majority and 12 minority appointments made by the 12 committee chairs and ranking members and one member jointly appointed by the chair and vice chair of the Joint Economic Committee.\n\n Vacancies\n\nAll five proposals provide that vacancies on the commission will not affect its powers and would be filled in the same manner as the original appointment.\n\n Deadline for Appointments\n\nThree of the bills propose specific deadlines for the appointment of commissioners. H.R. 6429 and H.R. 6548 provide that appointments are made between specific dates in January or February 2021. Further, H.R. 
6429 provides that commission members could be appointed in September 2020, if there is no longer a COVID-19 public health emergency in effect—as determined by the Secretary of Health and Human Services—as of August 31, 2020. H.R. 6440 would require all appointments be made by December 15, 2020. H.R. 6455 would require appointments to be made within 45 days after enactment. H.R. 6429 , H.R. 6440 , and H.R. 6548 would start the commission\'s work in early 2021, as the commission cannot operate without the appointment of members. H.R. 6429 , however would provide that the proposed commission\'s work would begin no later than October 31, 2020, if members are appointed in September 2020. H.R. 6431 does not specify a deadline for the appointment of members.\nTypically, deadlines for appointment can range from several weeks to several months. For example, the deadline for appointments to the Antitrust Modernization Commission was 60 days after the enactment of its establishing act. The deadline for appointment to the Commission on Wartime Contracting in Iraq and Afghanistan was 120 days from the date of enactment. The deadline for appointment to the 9/11 Commission was December 15, 2002, 18 days after enactment of the act.\n\n Rules of Procedure and Operations\n\nWhile most statutes that authorize congressional advisory commissions do not provide detailed procedures for how the commission should conduct its business, the statutory language may provide a general structure, including a mechanism for selecting a chair and procedures for creating rules. None of the five COVID-19 commission proposals contain language that directs the process for potentially adopting rules of procedure. For a comparison of each proposed commission\'s specified rules of procedures and operations, see Table 1 .\n\n Chair Selection\n\nEach bill provides for the selection of a chair and/or vice chair of the commission. H.R. 6429 , H.R. 6431 , and H.R. 6548 would have the chair appointed by the President and the vice chair appointed by congressional leaders of the political party opposite the President. H.R. 6440 would have the chair appointed by the Speaker of the House (in consultation with the Senate majority leader and the House minority leader) and the vice chair appointed by the Senate majority leader (in consultation with the Speaker of the House and the Senate minority leader). H.R. 6455 would have the chair and vice chair chosen from among commission members by a majority vote of the commission, and would require the chair and vice chair to have \"significant experience\" in areas to be studied by the commission.\n\n Initial Meeting Deadline\n\nAs with the timing of commission appointments, some authorizing statutes are prescriptive in when the commission\'s first meeting should take place. Three of the bills analyzed here provide specific time lines for the commission\'s first meeting. H.R. 6429 would require the first meeting to be no later than March 15, 2021, unless members are appointed in September 2020 (if no public health emergency exists). H.R. 6455 would require the first meeting within 45 days after the appointment of all commission members, which is—given the 45-day deadline for appointment—effectively a maximum of 90 days after enactment. H.R. 6548 would direct the commission to hold its initial meeting \"as soon as practicable,\" but not later than March 5, 2021. H.R. 6431 and H.R. 6440 do not provide for an initial meeting deadline. 
Instead, they direct the commission to meet \"as soon as practicable.\" \n\n Quorum\n\nMost commission statutes provide that a quorum will consist of a particular number of commissioners, usually a majority, but occasionally a supermajority. All five bills would provide for a quorum requirement. H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 would define a quorum as 6 (of 10) members. H.R. 6455 would provide that a quorum is 18 of 25 members (72%).\n\n Public Access\n\nAll five commission bills would require commission meetings to be open to the public. Each bill would also require that reports be made publicly available.\n\n Formulating Other Rules of Procedure and Operations\n\nAbsent statutory guidance (eithe r in general statutes or in individual statutes authorizing commissions), advisory entities vary widely in how they adopt their rules of procedure. In general, three models exist: formal written rules, informal rules, and the reliance on norms. Any individual advisory entity might make use of all three of these models for different types of decisionmaking. \nThe choice to adopt written rules or rely on informal norms to guide commission procedure may be based on a variety of factors, such as the entity\'s size, the frequency of meetings, member preferences regarding formality, the level of collegiality among members, and the amount of procedural guidance provided by the entity\'s authorizing statute. Regardless of how procedural issues are handled, protocol for decisionmaking regarding the following operational issues may be important for the commission to consider at the outset of its existence: eligibility to vote and proxy rules; staff hiring, compensation, and work assignments; hearings, meetings, and field visits; nonstaff expenditures and contracting; reports to Congress; budgeting; and procedures for future modification of rules. None of the five COVID-19 commission proposals specify that the proposed commission must adopt written rules.\n\n FACA Applicability\n\nThe Federal Advisory Committee Act (FACA) mandates certain structural and operational requirements, including formal reporting and oversight procedures, for certain federal advisory bodies that advise the executive branch. Three proposals ( H.R. 6429 , H.R. 6431 , and H.R. 6548 ) specifically exempt the proposed commission from FACA. Of the remaining two, FACA would also likely not apply to the commission proposed in H.R. 6455 because it would be appointed entirely by Members of Congress, although it only specifies that its final report is public, not whether it is specifically sent to Congress and/or the President. It is not clear that FACA would apply to the commission proposed in H.R. 6440 . Although it includes a presidential appointment and its report would be sent to both Congress and the President, its establishment clause specifies that the commission \"is established in the legislative branch,\" and a super-majority of its members would be appointed by Congress.\n\n Duties and Reporting Requirements\n\nMost congressional commissions are generally considered policy commissions—temporary bodies that study particular policy problems and report their findings to Congress or review a specific event. \n\n General Duties\n\nAll five of the proposed commissions would be tasked with duties that are analogous to those of past policy commissions. 
While the specific mandates differ somewhat, all proposed commissions are tasked with investigating aspects of the COVID-19 pandemic and submitting one or more reports that include the commission\'s findings, conclusions, and recommendations for legislative action. H.R. 6440 would specifically require the commission to avoid unnecessary duplication of work being conducted by the Government Accountability Office (GAO), congressional committees, and executive branch agency and independent commission investigations.\n\n Reports\n\nEach proposed commission would be tasked with issuing a final report detailing its findings, conclusions, and recommendations. H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 would provide that the commission \"may submit\" interim reports to Congress and the President, but do not provide time lines on when those reports might be submitted. In each case, the interim report would need to be agreed to by a majority of commission members. H.R. 6431 would also require the commission to submit a report on actions taken by the states and a report on essential products, materials, ingredients, and equipment required to fight pandemics.\nH.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 also specify that final reports shall be agreed to by a majority of commission members. H.R. 6455 does not specify a vote threshold for approval of its report.\nNone of the bills make specific provisions for the inclusion of minority viewpoints. Presumably this would leave each commission with discretion on whether to include or exclude minority viewpoints. Past advisory entities have been proposed or established with a variety of statutory reporting conditions, including the specification of majority or super-majority rules for report adoption and provisions requiring the inclusion of minority viewpoints. In practice, advisory bodies that are not given statutory direction on these matters have tended to work under simple-majority rules for report adoption.\n\n Report Deadlines\n\nH.R. 6429 would require a final report one year after the commission\'s initial meeting. H.R. 6431 and H.R. 6440 would require a final report not later than 18 months after enactment. H.R. 6455 would require a final report to be published not later than 18 months after the commission\'s first meeting. \nH.R. 6548 would require a final report by October 15, 2021. This deadline could be extended by 90 days upon a vote of no fewer than 8 (out of 10) commission members. The commission could vote to extend its final report deadline up to three times, and would be required to notify Congress, the President, and the public of any such extension.\nWhile such a deadline would potentially give the commission a defined period of time to complete its work, setting a particular date for report completion could potentially create unintended time constraints. Any delay in the passage of the legislation or in the appointment process would reduce the amount of time the commission has to complete its work, even with the opportunity for the commission to extend its own deadline up to three times.\nThe length of time a congressional commission has to complete its work is arguably one of the most consequential decisions when designing an advisory entity. 
If the entity has a short window of time, the quality of its work product may suffer or it may not be able to fulfill its statutory mandate on time.\nOn the other hand, if the commission is given a long period of time to complete its work, it may undermine one of a commission\'s primary legislative advantages, the timely production of expert advice on a current matter. A short deadline may also affect the process of standing up a new commission. The selection of commissioners, recruitment of staff, arrangement of office space, and other logistical matters may require expedited action if short deadlines need to be met.\n\n Report Submission\n\nOf the five proposed commissions, four ( H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6548 ) are directed to submit their reports to both Congress and the President. H.R. 6455 requires that the report is made public.\nMost congressional advisory commissions are required to submit their reports to Congress, and sometimes to the President or an executive department or agency head. For example, the National Commission on Severely Distressed Public Housing\'s final report was submitted to both Congress and the Secretary of Housing and Urban Development.\n\n Commission Termination\n\nCongressional commissions are usually statutorily mandated to terminate. Termination dates for most commissions are linked to either a fixed period of time after the establishment of the commission, the selection of members, or the date of submission of the commission\'s final report. Alternatively, some commissions are given fixed calendar termination dates.\nAll five commission proposals would provide for the commission to terminate within a certain period of time following submission of its final report. H.R. 6429 , H.R. 6431 , H.R. 6440 , and H.R. 6455 would each direct the commission to terminate 60 days after the submission; H.R. 6548 specifies a time line of 90 days after submission.\n\n Commission Powers\n\nEach of the five proposals would provide the proposed commission with certain powers to carry out its mission (see Table 1 for specifics). One general issue for commissions is who is authorized to execute such powers. In some cases, the commission itself executes its powers, with the commission deciding whether to devise rules and procedures for the general use of such power. In other cases, the legislation specifically authorizes the commission to give discretionary power to subcommittees or individual commission members. Finally, the legislation itself might grant certain powers to individual members of the commission, such as the chair.\n\n Hearings and Evidence\n\nAll five bills would provide the proposed commission with the power to hold hearings, take testimony, and receive evidence. All five commissions would also be provided the power to administer oaths to witnesses.\n\n Subpoenas\n\nFour of the bills would provide the commission with subpoena power. H.R. 6440 would not provide subpoena power to the commission. H.R. 6429 , H.R. 6431 , and H.R. 6548 would provide that subpoenas could only be issued by either (1) agreement of the chair and vice chair, or (2) the affirmative vote of 6 (of 10) commission members. H.R. 6455 would require that a subpoena could only be issued by either agreement of the chair and vice chair or an affirmative vote of 18 (of 25) commission members. 
All four bills that would provide subpoena power contain substantially similar judicial methods of subpoena enforcement.\n\n Administrative Support\n\nAll five of the bills would provide that the commission receive administrative support from the General Services Administration (GSA). The GSA provides administrative support to dozens of federal entities, including congressional advisory commissions. Each of the five bills would provide that GSA be reimbursed for its services by the commission. Each bill also provides that other departments or agencies may provide funds, facilities, staff, and other services to the commission.\n\n Other Powers\n\nWithout explicit language authorizing certain activities, commissions often cannot gather information, enter into contracts, use the U.S. mail like an executive branch entity, or accept donations or gifts. \nAll five bills direct that federal agencies provide information to the commission upon request. H.R. 6429 , H.R. 6431 , and H.R. 6548 would also provide that the commission could use the U.S. mails in the same manner as any department or agency, enter into contracts, and accept gifts or donations of services or property.\n\n Staffing\n\nThe proposed COVID-19 commissions contain staffing provisions commonly found in congressional advisory commission legislation. Congressional advisory commissions are usually authorized to hire staff. Most statutes specify that the commission may hire a lead staffer, often referred to as a \"staff director,\" \"executive director,\" or another similar title, in addition to additional staff as needed. Rather than mandate a specific staff size, many commissions are instead authorized to appoint a staff director and other personnel as necessary, subject to the limitations of available funds.\nMost congressional commissions are also authorized to hire consultants, procure intermittent services, and request that federal agencies detail personnel to aid the work of the commission.\n\n Director and Commission Staff\n\nFour of the bills provide that the commission may hire staff without regard to certain laws regarding the competitive service; H.R. 6440 does not specifically exempt the commission from such laws. Four bills ( H.R. 6429 , H.R. 6431 , H.R. 6455 , and H.R. 6548 ) would authorize, but not require, the commission to hire a staff director and additional staff, as appropriate. Four proposals would limit staff salaries to level V of the executive schedule. Three of the bills would specifically designate staff as federal employees for the purposes of certain laws, such as workman\'s compensation, retirement, and other benefits.\n\n Detailees\n\nWhen authorized, some commissions can have federal agency staff detailed to the commission. All five bills would provide that federal employees could be detailed to the commission. Four bills would provide that the detailee would be without reimbursement to his or her home agency. H.R. 6440 would allow detailees on a reimbursable basis. \n\n Experts and Consultants\n\nAll five bills would provide the commission with the authority to hire experts and consultants. Four of the bills limit the rate of pay for consultants to level IV of the Executive Schedule. H.R. 6440 does not specify a specific limit.\n\n Security Clearances\n\nFour bills would provide that federal agencies and departments shall cooperate with the commission to provide members and staff appropriate security clearances. H.R. 
6440 does not contain a security clearance provision.\n\n Funding and Costs\n\nCommissions generally require funding to help meet their statutory goals. When designing a commission, therefore, policymakers may consider both how the commission will be funded, and how much funding the commission will be authorized to receive. Four of the five proposals specify a funding mechanism for the commission.\nHow commissions are funded and the amounts that they receive vary considerably. Several factors can contribute to overall commission costs. These factors might include the cost of hiring staff, contracting with outside consultants, and engaging administrative support, among others. Additionally, most commissions reimburse the travel expenditures of commissioners and staff, and some compensate their members. The duration of a commission can also significantly affect its cost; past congressional commissions have been designed to last anywhere from several months to several years.\n\n Costs\n\nIt is difficult to estimate or predict the potential overall cost of any commission. Annual budgets for congressional advisory entities range from several hundred thousand dollars to millions of dollars annually. Overall expenses for any individual advisory entity depend on a variety of factors, the most important of which are the number of paid staff and the commission\'s duration and scope. Some commissions have few full-time staff; others employ large numbers, such as the National Commission on Terrorist Attacks Upon the United States, which had a full-time paid staff of nearly 80. Secondary factors that can affect commission costs include the number of commissioners, how often the commission meets or holds hearings, whether or not the commission travels or holds field hearings, and the publications the commission produces.\n\n Authorized Funding\n\nThree of the bills ( H.R. 6429 , H.R. 6440 , and H.R. 6548 ) would authorize the appropriation of \"such sums as may be necessary\" for the commission, to be derived in equal amounts from the contingent fund of the Senate and the applicable accounts of the House of Representatives. H.R. 6429 and H.R. 6548 would provide that funds are available until the commission terminates. H.R. 6455 would authorize the appropriation of $4 million for the commission, to remain available until the commission terminates. H.R. 6431 does not include an authorization of appropriations.\n\n Comparison of Proposals to Create a COVID-19 Commission\n\n Table 1 provides a side-by-side comparison of major provisions of the five proposals. 
For each bill, the membership structure, appointment structure, rules of procedure and operation, duties and reporting requirements, proposed commission powers, staffing provisions, and funding are compared.\n\nSummary:\n \ No newline at end of file diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 6c1f284b16..51b32ca01b 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -16,6 +16,8 @@ def run_eval(args): + set_ulimit() + if "OPENAI_API_KEY" not in os.environ: os.environ["OPENAI_API_KEY"] = "EMPTY" @@ -39,6 +41,14 @@ def run_eval(args): eval_obj = MathEval( filename, equality_checker, args.num_examples, args.num_threads ) + elif args.eval_name == "mgsm": + from sglang.test.simple_eval_mgsm import MGSMEval + + eval_obj = MGSMEval(args.num_examples, args.num_threads) + elif args.eval_name == "mgsm_en": + from sglang.test.simple_eval_mgsm import MGSMEval + + eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"]) elif args.eval_name == "gpqa": from sglang.test.simple_eval_gpqa import GPQAEval @@ -109,7 +119,6 @@ def run_eval(args): parser.add_argument("--eval-name", type=str, default="mmlu") parser.add_argument("--num-examples", type=int) parser.add_argument("--num-threads", type=int, default=512) - set_ulimit() args = parser.parse_args() run_eval(args) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e5ad3ea9d3..ac69ab875b 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -14,7 +14,8 @@ """ import json -import multiprocessing +import multiprocessing as mp +import os from dataclasses import dataclass from typing import List, Union @@ -23,16 +24,22 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from sglang.srt.server import Runtime -from sglang.srt.utils import is_generation_model +from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER DEFAULT_PROMPTS = [ # the output of gemma-2-2b from SRT is unstable on the commented prompt # "The capital of France is", - "The capital of the United Kindom is", + "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", + "The capital of the United Kingdom is", "Today is a sunny day and I like", "AI is a field of computer science focused on", ] +dirpath = os.path.dirname(__file__) +with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f: + long_prompt = f.read() +DEFAULT_PROMPTS.append(long_prompt) + NUM_TOP_LOGPROBS = 5 @@ -56,44 +63,37 @@ class HFRunner: def __init__( self, model_path, - torch_dtype=torch.float16, - is_generation_model=None, + torch_dtype, + is_generation, ): - self.in_queue = multiprocessing.Queue() - self.out_queue = multiprocessing.Queue() + self.is_generation = is_generation + + self.in_queue = mp.Queue() + self.out_queue = mp.Queue() - self.model_proc = multiprocessing.Process( + self.model_proc = mp.Process( target=self.start_model_process, args=( self.in_queue, self.out_queue, model_path, torch_dtype, - is_generation_model, ), ) self.model_proc.start() - def start_model_process( - self, in_queue, out_queue, model_path, torch_dtype, is_generation_model - ): + def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): self.tokenizer = AutoTokenizer.from_pretrained( model_path, torch_dtype=torch_dtype, - trust_remote_code=True, ) - self.is_generation_model = ( - is_generation_model(model_path) - if is_generation_model is None - else is_generation_model - ) - if self.is_generation_model: + if self.is_generation: self.model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch_dtype, + trust_remote_code=False, low_cpu_mem_usage=True, - trust_remote_code=True, ).cuda() else: from sentence_transformers import SentenceTransformer @@ -106,7 +106,7 @@ def start_model_process( while True: prompts, max_new_tokens = in_queue.get() if prompts is not None: - if self.is_generation_model: + if self.is_generation: output_strs = [] prefill_logprobs = [] for p in prompts: @@ -125,16 +125,14 @@ def start_model_process( ) logits = self.model.forward(input_ids).logits[0] - logprobs = F.log_softmax( - logits, dim=-1, dtype=torch.float32 - ).tolist() - # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1]) - # print("index", index_of_max) - logprobs = [ - sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS] - for token_logprobs in logprobs - ] - prefill_logprobs.append(logprobs) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + logprobs, top_indices = torch.topk( + logprobs, k=NUM_TOP_LOGPROBS, dim=-1 + ) + # print("index", top_indices) + prefill_logprobs.append(logprobs.tolist()) + del logits + del logprobs out_queue.put( ModelOutput( @@ -171,19 +169,20 @@ class SRTRunner: def __init__( self, model_path, + torch_dtype, + is_generation, tp_size=1, - torch_dtype=torch.float16, - is_generation_model=None, + port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER, ): - self.is_generation_model = ( - is_generation_model(model_path) - if is_generation_model is None - else is_generation_model - ) + self.is_generation = is_generation self.runtime = Runtime( model_path=model_path, tp_size=tp_size, dtype=get_dtype_str(torch_dtype), + port=port, + mem_fraction_static=0.69, + trust_remote_code=False, + is_embedding=not self.is_generation, ) def forward( @@ -191,7 +190,7 @@ def forward( prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=8, ): - if self.is_generation_model: + if self.is_generation: # the return value contains logprobs from prefill output_strs = [] top_input_logprobs = [] @@ -201,6 +200,7 @@ def forward( prompt, sampling_params=sampling_params, return_logprob=True, + logprob_start_len=0, 
top_logprobs_num=NUM_TOP_LOGPROBS, ) response = json.loads(response) diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py index 4cfd3515fe..d97d84de93 100644 --- a/python/sglang/test/simple_eval_common.py +++ b/python/sglang/test/simple_eval_common.py @@ -1,13 +1,12 @@ # Adapted from https://github.com/openai/simple-evals/ -import base64 import os import resource import time from collections import defaultdict from dataclasses import dataclass, field from multiprocessing.pool import ThreadPool -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import httpx import jinja2 @@ -44,8 +43,8 @@ class EvalResult: Result of running an evaluation (usually consisting of many samples) """ - score: float | None # top-line metric - metrics: Dict[str, float] | None # other metrics + score: Optional[float] # top-line metric + metrics: Optional[Dict[str, float]] # other metrics htmls: List[str] # strings of valid HTML convos: List[MessageList] # sampled conversations @@ -56,10 +55,10 @@ class SingleEvalResult: Result of evaluating a single sample """ - score: float | None + score: Optional[float] metrics: Dict[str, float] = field(default_factory=dict) - html: str | None = None - convo: MessageList | None = None # sampled conversation + html: Optional[str] = None + convo: Optional[MessageList] = None # sampled conversation class Eval: @@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase): def __init__( self, base_url: str = None, - model: str | None = None, - system_message: str | None = None, + model: Optional[str] = None, + system_message: Optional[str] = None, temperature: float = 0.0, max_tokens: int = 2048, ): @@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str): def aggregate_results( single_eval_results: List[SingleEvalResult], default_stats: Tuple[str] = ("mean", "std"), - name2stats: Dict[str, Tuple[str]] | None = None, + name2stats: Optional[Dict[str, Tuple[str]]] = None, ) -> EvalResult: """ Aggregate results from multiple evaluations into a single EvalResult. 
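The annotation changes above swap the PEP 604 union syntax (for example "score: float | None") for typing.Optional so that simple_eval_common.py still imports on Python versions older than 3.10, where evaluating an "X | None" annotation raises a TypeError. A minimal sketch of the difference, using an illustrative dataclass rather than the real EvalResult:

    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass
    class ExampleResult:  # illustrative stand-in, not the class from the patch
        score: Optional[float] = None  # accepted on Python 3.8 and newer
        metrics: Optional[Dict[str, float]] = None
        # score: float | None = None  # would raise TypeError before Python 3.10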
diff --git a/python/sglang/test/simple_eval_gpqa.py b/python/sglang/test/simple_eval_gpqa.py index 46055caa5f..ec2abb4adc 100644 --- a/python/sglang/test/simple_eval_gpqa.py +++ b/python/sglang/test/simple_eval_gpqa.py @@ -8,6 +8,7 @@ import random import re +from typing import Optional import pandas @@ -28,7 +29,7 @@ class GPQAEval(Eval): def __init__( self, filename: str, - num_examples: int | None, + num_examples: Optional[int], num_threads: int, n_repeats: int = 1, ): diff --git a/python/sglang/test/simple_eval_humaneval.py b/python/sglang/test/simple_eval_humaneval.py index 7a0f90c467..b0ad79d413 100644 --- a/python/sglang/test/simple_eval_humaneval.py +++ b/python/sglang/test/simple_eval_humaneval.py @@ -6,21 +6,15 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/ """ -import json -import logging -import multiprocessing import random import re -from collections import Counter, defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed -from io import BytesIO -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Optional -import blobfile as bf import tqdm try: - from human_eval.data import HUMAN_EVAL, read_problems + from human_eval.data import read_problems from human_eval.evaluation import estimate_pass_at_k from human_eval.execution import check_correctness # , unsafe_execute except (ImportError, ModuleNotFoundError): @@ -67,7 +61,7 @@ def evaluate_functional_correctness( class HumanEval(Eval): def __init__( self, - num_examples: int | None, + num_examples: Optional[int], num_threads: int, num_samples_per_task: int = 5, ks_passes: List[int] = [1, 2, 5], diff --git a/python/sglang/test/simple_eval_math.py b/python/sglang/test/simple_eval_math.py index 4ddb650d96..74c49abe51 100644 --- a/python/sglang/test/simple_eval_math.py +++ b/python/sglang/test/simple_eval_math.py @@ -8,6 +8,7 @@ import random import re +from typing import Optional import pandas @@ -36,7 +37,7 @@ def __init__( self, filename: str, equality_checker: SamplerBase, - num_examples: int | None, + num_examples: Optional[int], num_threads: int, ): df = pandas.read_csv(filename) diff --git a/python/sglang/test/simple_eval_mgsm.py b/python/sglang/test/simple_eval_mgsm.py new file mode 100644 index 0000000000..ce00a1ac76 --- /dev/null +++ b/python/sglang/test/simple_eval_mgsm.py @@ -0,0 +1,203 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems. 
+Language Models are Multilingual Chain-of-Thought Reasoners +Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei +https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp +""" + +import re +import urllib +from typing import Optional + +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ( + HTML_JINJA, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, +) + +ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"] +LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"] +NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"] + +LANG_TO_FPATH = { + "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv", + "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv", + "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv", + "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv", + "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv", + "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv", + "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv", + "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv", + "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv", + "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv", + "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv", +} +LANG_TO_INSTRUCTIONS = { + "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:". + +{input}""", + "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।. + +{input}""", + "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu. + +{input}""", + "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:". + +{input}""", + "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:". + +{input}""", + "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。 + +{input}""", + "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:". + +{input}""", + "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:". + +{input}""", + "te": """ఈ గణిత సమస్యను పరిష్కరించండి. 
చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు. + +{input}""", + "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:" + +{input}""", + "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。 + +{input}""", +} + +LANG_TO_ANSWER_PREFIX = { + "en": "Answer", + "bn": "উত্তর", + "de": "Antwort", + "es": "Respuesta", + "fr": "Réponse", + "ja": "答え", + "ru": "Ответ", + "sw": "Jibu", + "te": "సమాధానం", + "th": "คำตอบ", + "zh": "答案", +} + + +def parse_answer(answer: str, answer_prefix: str) -> str: + if answer_prefix not in answer: + return "" + + answer_text = answer.split(answer_prefix)[-1].strip() + + # find all the numbers (including decimals) in the string + numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", "")) + + # return the first number (removing trailing decimal point if present), + # or an empty string if there were no numbers + return numbers[-1].rstrip(".") if numbers else "" + + +def score_mgsm(target: str, prediction: str) -> bool: + if "." in prediction: + prediction = prediction.rstrip("0").rstrip(".") + + target = target.replace(",", "") + prediction = prediction.replace(",", "") + + return target == prediction + + +def get_lang_examples(lang: str) -> list[dict[str, str]]: + fpath = LANG_TO_FPATH[lang] + examples = [] + with urllib.request.urlopen(fpath) as f: + for line in f.read().decode("utf-8").splitlines(): + inputs, targets = line.strip().split("\t") + if "." in targets: + raise ValueError(f"targets {targets} contains a decimal point.") + # targets = int(targets.replace(",", "")) + examples.append({"inputs": inputs, "targets": targets, "lang": lang}) + return examples + + +def get_all_examples() -> list[dict[str, str]]: + examples = [] + for lang in ALL_LANGUAGES: + if lang != "en": + continue + examples += get_lang_examples(lang) + return examples + + +class MGSMEval(Eval): + def __init__( + self, + num_examples_per_lang: int = 250, # restrict to a subset of the data for debugging + num_threads: int = 64, + languages: Optional[list[str]] = ALL_LANGUAGES, + ): + if languages is None: + languages = ALL_LANGUAGES + else: + for language in languages: + if language not in ALL_LANGUAGES: + raise ValueError( + f"language {language} is not a valid language. 
" + f"It should be one in {ALL_LANGUAGES}" + ) + self._languages = languages + self._num_examples_per_lang = num_examples_per_lang + self._num_threads = num_threads + + examples = [] + for lang in self._languages: + lang_examples = get_lang_examples(lang) + examples.extend(lang_examples[: self._num_examples_per_lang]) + self.examples = examples + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(example: dict[str, str]): + language = example["lang"] + latin_language = ( + "group_latin" if language in LATIN_LANGUAGES else "group_non_latin" + ) + correct_answer = example["targets"] + instructoin = LANG_TO_INSTRUCTIONS[language] + prompt_messages = [ + sampler._pack_message( + content=instructoin.format(input=example["inputs"]), role="user" + ) + ] + try: + response_text = sampler(prompt_messages) + except Exception as e: + response_text = "" + + answer_prefix = LANG_TO_ANSWER_PREFIX[language] + extracted_answer = parse_answer(response_text, answer_prefix) + + score = score_mgsm(correct_answer, extracted_answer) + html = common.jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=correct_answer, + extracted_answer=extracted_answer, + ) + convo = prompt_messages + [dict(content=response_text, role="assistant")] + return SingleEvalResult( + html=html, + score=score, + convo=convo, + metrics={language: score, latin_language: score}, + ) + + results = common.map_with_progress( + fn, self.examples, num_threads=self._num_threads + ) + return common.aggregate_results(results, default_stats=("mean", "std")) diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py index 3c0287510c..36a5c7fe35 100644 --- a/python/sglang/test/simple_eval_mmlu.py +++ b/python/sglang/test/simple_eval_mmlu.py @@ -8,6 +8,7 @@ import random import re +from typing import Optional import pandas @@ -84,7 +85,7 @@ class MMLUEval(Eval): - def __init__(self, filename: str, num_examples: int | None, num_threads: int): + def __init__(self, filename: str, num_examples: Optional[int], num_threads: int): df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: diff --git a/python/sglang/test/test_activation.py b/python/sglang/test/test_activation.py new file mode 100644 index 0000000000..357a23319b --- /dev/null +++ b/python/sglang/test/test_activation.py @@ -0,0 +1,55 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.activation import GeluAndMul + + +class TestGeluAndMul(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 2048] + D = [512, 4096, 5120, 13824] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed): + torch.manual_seed(seed) + + layer = GeluAndMul().to(dtype=dtype) + x = torch.randn(num_tokens, 2 * d, dtype=dtype) + + with torch.inference_mode(): + ref_out = layer.forward_native(x) + out = layer.forward_cuda(x) + + if dtype == torch.bfloat16: + atol = rtol = 1e-2 + else: + atol = rtol = 1e-3 + + self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol)) + + def test_gelu_and_mul(self): + for params in itertools.product( + self.NUM_TOKENS, + self.D, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + d=params[1], + 
dtype=params[2], + seed=params[3], + ): + self._run_gelu_and_mul_test(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/sglang/test/test_layernorm.py b/python/sglang/test/test_layernorm.py index ab61aa8040..770e69733d 100644 --- a/python/sglang/test/test_layernorm.py +++ b/python/sglang/test/test_layernorm.py @@ -3,7 +3,7 @@ import torch -from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm class TestRMSNorm(unittest.TestCase): @@ -56,5 +56,57 @@ def test_rms_norm(self): self._run_rms_norm_test(*params) +class TestGemmaRMSNorm(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 4096] + HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] + ADD_RESIDUAL = [False, True] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gemma_rms_norm_test( + self, num_tokens, hidden_size, add_residual, dtype, seed + ): + torch.manual_seed(seed) + + layer = GemmaRMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + residual = torch.randn_like(x) * scale if add_residual else None + + with torch.inference_mode(): + ref_out = layer.forward_native(x, residual) + out = layer(x, residual) + + if add_residual: + self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3)) + self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3)) + else: + self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3)) + + def test_gemma_rms_norm(self): + for params in itertools.product( + self.NUM_TOKENS, + self.HIDDEN_SIZES, + self.ADD_RESIDUAL, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + hidden_size=params[1], + add_residual=params[2], + dtype=params[3], + seed=params[4], + ): + self._run_gemma_rms_norm_test(*params) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 7c7c9bdcb1..ce40255855 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -103,16 +103,19 @@ def decode_int(s): def test_decode_json_regex(): @sgl.function def decode_json(s): - from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING + from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR s += "Generate a JSON object to describe the basic city information of Paris.\n" + s += "Here are the JSON object:\n" + + # NOTE: we recommend using dtype gen or whole regex string to control the output with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" - s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" + s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n" + s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n" + s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n" + s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n" s += "}" ret = decode_json.run(temperature=0.0) @@ -359,6 +362,30 @@ def regex_gen(s): assert re.match(regex, answer) +def test_dtype_gen(): + @sgl.function + def dtype_gen(s): + s += "Q: What 
is the full name of DNS?\n" + s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n" + s += "Q: Which year was DNS invented?\n" + s += "A: " + sgl.gen("int_res", dtype=int) + "\n" + s += "Q: What is the value of pi?\n" + s += "A: " + sgl.gen("float_res", dtype=float) + "\n" + s += "Q: Is the sky blue?\n" + s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n" + + state = dtype_gen.run() + + try: + state["int_res"] = int(state["int_res"]) + state["float_res"] = float(state["float_res"]) + state["bool_res"] = bool(state["bool_res"]) + # assert state["str_res"].startswith('"') and state["str_res"].endswith('"') + except ValueError: + print(state) + raise + + def test_completion_speculative(): @sgl.function(num_api_spec_tokens=64) def gen_character_spec(s): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 7243ff2ecd..d6a1792b85 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -2,11 +2,10 @@ import argparse import asyncio -import multiprocessing +import os import subprocess import threading import time -import unittest from functools import partial from typing import Callable, List, Optional @@ -18,10 +17,19 @@ from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.srt.utils import kill_child_process from sglang.utils import get_exception_traceback DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" -DEFAULT_URL_FOR_TEST = "http://127.0.0.1:8157" +DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" +DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 + +if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157 + DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157" +else: + DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157 + DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157" def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): @@ -100,31 +108,8 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None): return pred -def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None): - import grpc - from ginfer import sampler_pb2, sampler_pb2_grpc - - sampler_channel = grpc.insecure_channel(url.replace("http://", "")) - sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) - - if stop is None: - stop_strings = None - else: - stop_strings = [stop] - - sample_request = sampler_pb2.SampleTextRequest( - prompt=prompt, - settings=sampler_pb2.SampleSettings( - max_len=max_tokens, - rng_seed=0, - temperature=max(temperature, 1e-7), - nucleus_p=1, - stop_strings=stop_strings, - ), - ) - stream = sampler.SampleText(sample_request) - response = "".join([x.text for x in stream]) - return response +def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None): + raise NotImplementedError() def call_generate_guidance( @@ -267,7 +252,7 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser): "vllm", "outlines", "lightllm", - "ginfer", + "gserver", "guidance", "lmql", "srt-raw", @@ -288,7 +273,7 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser): "lightllm": 22000, "lmql": 23000, "srt-raw": 30000, - "ginfer": 9988, + "gserver": 9988, } args.port = default_port.get(args.backend, None) return args @@ -324,8 +309,8 @@ def _get_call_generate(args: argparse.Namespace): return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate") 
elif args.backend == "srt-raw": return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate") - elif args.backend == "ginfer": - return partial(call_generate_ginfer, url=f"{args.host}:{args.port}") + elif args.backend == "gserver": + return partial(call_generate_gserver, url=f"{args.host}:{args.port}") elif args.backend == "outlines": return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate") elif args.backend == "guidance": @@ -476,34 +461,36 @@ def run_unittest_files(files: List[str], timeout_per_file: float): success = True for filename in files: + global process - def func(): - print(f"\n\nRun {filename}\n\n") - ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) - - p = multiprocessing.Process(target=func) - - def run_one_file(): - p.start() - p.join() + def run_one_file(filename): + filename = os.path.join(os.getcwd(), filename) + print(f"\n\nRun:\npython3 {filename}\n\n", flush=True) + process = subprocess.Popen( + ["python3", filename], stdout=None, stderr=None, env=os.environ + ) + process.wait() + return process.returncode try: - run_with_timeout(run_one_file, timeout=timeout_per_file) - if p.exitcode != 0: - success = False - break + ret_code = run_with_timeout( + run_one_file, args=(filename,), timeout=timeout_per_file + ) + assert ret_code == 0 except TimeoutError: - p.terminate() + kill_child_process(process.pid) time.sleep(5) print( - f"\nTimeout after {timeout_per_file} seconds when running {filename}\n" + f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", + flush=True, ) - return False + success = False + break if success: - print(f"Success. Time elapsed: {time.time() - tic:.2f}s") + print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True) else: - print(f"Fail. Time elapsed: {time.time() - tic:.2f}s") + print(f"Fail. 
Time elapsed: {time.time() - tic:.2f}s", flush=True) return 0 if success else -1 diff --git a/python/sglang/version.py b/python/sglang/version.py index 5635676f6b..ad954de503 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.2.11" +__version__ = "0.2.14.post2" diff --git a/scripts/convert_yi_vl.py b/scripts/deprecated/convert_yi_vl.py similarity index 100% rename from scripts/convert_yi_vl.py rename to scripts/deprecated/convert_yi_vl.py diff --git a/scripts/convert_yi_vl.sh b/scripts/deprecated/convert_yi_vl.sh similarity index 100% rename from scripts/convert_yi_vl.sh rename to scripts/deprecated/convert_yi_vl.sh diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index ac91b3bed4..d2d3116101 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -35,18 +35,17 @@ def normal_text(args): args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, + device_map="auto", trust_remote_code=True, ) m.cuda() - print(m) - prompts = [ "The capital of France is", "The capital of the United Kindom is", "Today is a sunny day and I like", ] - max_new_tokens = 32 + max_new_tokens = 16 for p in prompts: if isinstance(p, str): @@ -58,10 +57,11 @@ def normal_text(args): input_ids, do_sample=False, max_new_tokens=max_new_tokens ) output_str = t.decode(output_ids[0]) - print(output_str) prefill_logits = m.forward(input_ids).logits[0][-1] + print("prefill logits", prefill_logits) + print(output_str) @torch.inference_mode() diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index b2a07ae36c..fcd86ae3d3 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,10 +1,10 @@ -import json import unittest import sglang as sgl from sglang.test.test_programs import ( test_decode_int, test_decode_json_regex, + test_dtype_gen, test_expert_answer, test_few_shot_qa, test_mt_bench, @@ -59,6 +59,9 @@ def test_stream(self): def test_regex(self): test_regex() + def test_dtype_gen(self): + test_dtype_gen() + if __name__ == "__main__": unittest.main() diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index 520e811a80..a5a73bf319 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -13,6 +13,7 @@ limitations under the License. 
""" +import multiprocessing as mp import unittest import torch @@ -20,7 +21,10 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner from sglang.test.test_utils import get_similarities -MODELS = [("intfloat/e5-mistral-7b-instruct", 1)] +MODELS = [ + ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5), + ("intfloat/e5-mistral-7b-instruct", 1, 1e-5), +] TORCH_DTYPES = [torch.float16] @@ -32,9 +36,10 @@ def assert_close_prefill_logits( model_path, tp_size, torch_dtype, + prefill_tolerance, ) -> None: with HFRunner( - model_path, torch_dtype=torch_dtype, is_generation_model=False + model_path, torch_dtype=torch_dtype, is_generation=False ) as hf_runner: hf_outputs = hf_runner.forward(prompts) @@ -42,30 +47,34 @@ def assert_close_prefill_logits( model_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation_model=False, + is_generation=False, ) as srt_runner: - srt_outputs = srt_runner.forward( - prompts, - ) + srt_outputs = srt_runner.forward(prompts) for i in range(len(prompts)): hf_logits = torch.Tensor(hf_outputs.embed_logits[i]) srt_logits = torch.Tensor(srt_outputs.embed_logits[i]) - similarities = torch.tensor(get_similarities(hf_logits, srt_logits)) + similarity = torch.tensor(get_similarities(hf_logits, srt_logits)) + print("similarity diff", abs(similarity - 1)) - tolerance = 1e-2 - assert torch.all( - abs(similarities - 1) < tolerance - ), f"embeddings not all close" + if len(prompts[i]) <= 1000: + assert torch.all( + abs(similarity - 1) < prefill_tolerance + ), "embeddings are not all close" def test_prefill_logits(self): - for model, tp_size in MODELS: + for model, tp_size, prefill_tolerance in MODELS: for torch_dtype in TORCH_DTYPES: self.assert_close_prefill_logits( - DEFAULT_PROMPTS, model, tp_size, torch_dtype + DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance ) if __name__ == "__main__": - unittest.main(warnings="ignore") + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main() diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index ca4f096e30..08288c510c 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -13,6 +13,7 @@ limitations under the License. 
""" +import multiprocessing as mp import unittest import torch @@ -20,14 +21,47 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner MODELS = [ - ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1), - ("google/gemma-2-2b", 1), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1), + ("google/gemma-2-2b", 1, 3, 3e-2, 1), + ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1), ] TORCH_DTYPES = [torch.float16] -class TestGenerationModels(unittest.TestCase): +def lcs(X, Y): + m = len(X) + n = len(Y) + L = [[0] * (n + 1) for _ in range(m + 1)] + + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0: + L[i][j] = 0 + elif X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + + return L[m][n] + + +def calculate_rouge_l(output_strs_list1, output_strs_list2): + rouge_l_scores = [] + + for s1, s2 in zip(output_strs_list1, output_strs_list2): + lcs_len = lcs(s1, s2) + precision = lcs_len / len(s1) if len(s1) > 0 else 0 + recall = lcs_len / len(s2) if len(s2) > 0 else 0 + if precision + recall > 0: + fmeasure = (2 * precision * recall) / (precision + recall) + else: + fmeasure = 0.0 + rouge_l_scores.append(fmeasure) + return rouge_l_scores + + +class TestGenerationModels(unittest.TestCase): def assert_close_prefill_logits_and_output_strs( self, prompts, @@ -35,9 +69,14 @@ def assert_close_prefill_logits_and_output_strs( tp_size, torch_dtype, max_new_tokens, + prefill_tolerance, + rouge_threshold, + long_context_tolerance, ) -> None: + if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": + prompts = prompts[:-1] with HFRunner( - model_path, torch_dtype=torch_dtype, is_generation_model=True + model_path, torch_dtype=torch_dtype, is_generation=True ) as hf_runner: hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens) @@ -45,7 +84,7 @@ def assert_close_prefill_logits_and_output_strs( model_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation_model=True, + is_generation=True, ) as srt_runner: srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) @@ -53,25 +92,48 @@ def assert_close_prefill_logits_and_output_strs( hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) - tolerance = 3e-2 - assert torch.all( - abs(hf_logprobs - srt_logprobs) < tolerance - ), f"prefill logprobs not all close" - - assert hf_outputs.output_strs == srt_outputs.output_strs - - def test_prefill_logits(self): - for model, tp_size in MODELS: + print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs))) + if hf_logprobs.shape[0] <= 100: + assert torch.all( + abs(hf_logprobs - srt_logprobs) < prefill_tolerance + ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}" + + print(f"hf_outputs.output_strs={hf_outputs.output_strs}") + print(f"srt_outputs.output_strs={srt_outputs.output_strs}") + rouge_l_scores = calculate_rouge_l( + hf_outputs.output_strs, srt_outputs.output_strs + ) + print(f"rouge_l_scores={rouge_l_scores}") + assert all( + score >= rouge_threshold for score in rouge_l_scores + ), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}" + + def test_prefill_logits_and_output_strs(self): + for ( + model, + tp_size, + long_context_tolerance, + prefill_tolerance, + rouge_threshold, + ) in MODELS: for torch_dtype in TORCH_DTYPES: - max_new_tokens = 8 + max_new_tokens = 32 self.assert_close_prefill_logits_and_output_strs( 
DEFAULT_PROMPTS, model, tp_size, torch_dtype, max_new_tokens, + prefill_tolerance=prefill_tolerance, + rouge_threshold=rouge_threshold, + long_context_tolerance=long_context_tolerance, ) if __name__ == "__main__": - unittest.main(warnings="ignore") + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 08122389f9..cafcf3f2d5 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -5,18 +5,20 @@ suites = { "minimal": [ + "models/test_embedding_models.py", + "models/test_generation_models.py", + "sampling/penaltylib", "test_chunked_prefill.py", "test_embedding_openai_server.py", - "test_eval_accuracy.py", + "test_eval_accuracy_mini.py", "test_large_max_new_tokens.py", "test_openai_server.py", + "test_json_constrained.py", "test_skip_tokenizer_init.py", "test_torch_compile.py", + "test_triton_attn_backend.py", + "test_update_weights.py", "test_vision_openai_server.py", - "test_large_max_new_tokens.py", - "models/test_generation_models.py", - "models/test_embedding_models.py", - "sampling/penaltylib", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True @@ -31,6 +33,7 @@ tests.remove(target_suite_name) tests.extend(target_tests) + if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( @@ -46,6 +49,18 @@ choices=list(suites.keys()) + ["all"], help="The suite to run", ) + arg_parser.add_argument( + "--range-begin", + type=int, + default=0, + help="The begin index of the range of the files to run.", + ) + arg_parser.add_argument( + "--range-end", + type=int, + default=None, + help="The end index of the range of the files to run.", + ) args = arg_parser.parse_args() if args.suite == "all": @@ -53,5 +68,7 @@ else: files = suites[args.suite] + files = files[args.range_begin : args.range_end] + exit_code = run_unittest_files(files, args.timeout_per_file) exit(exit_code) diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index e72dc30f95..e3496102cb 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -5,7 +5,12 @@ import requests from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestBatchPenalizerE2E(unittest.TestCase): @@ -13,11 +18,11 @@ class TestBatchPenalizerE2E(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = f"http://127.0.0.1:{8157}" + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=( "--random-seed", "0", @@ -107,4 +112,4 @@ def test_repetition_penalty(self): if __name__ == "__main__": - unittest.main(warnings="ignore") + unittest.main() diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py index 3a9423bc5b..2eb704dc91 100644 --- a/test/srt/test_chunked_prefill.py +++ b/test/srt/test_chunked_prefill.py @@ -5,39 +5,55 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, 
DEFAULT_URL_FOR_TEST, popen_launch_server, ) -class TestAccuracy(unittest.TestCase): +class TestChunkedPrefill(unittest.TestCase): + def run_mmlu(self, disable_radix_cache, enable_mixed_chunk): + other_args = ["--chunked-prefill-size", "32"] + if disable_radix_cache: + other_args += ["--disable-radix-cache"] - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=["--chunked-prefill-size", "32"], - ) + if enable_mixed_chunk: + other_args += ["--enable-mixed-chunk"] - @classmethod - def tearDownClass(cls): - kill_child_process(cls.process.pid) + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) - def test_mmlu(self): args = SimpleNamespace( - base_url=self.base_url, - model=self.model, + base_url=base_url, + model=model, eval_name="mmlu", - num_examples=20, - num_threads=20, + num_examples=32, + num_threads=32, ) - metrics = run_eval(args) - assert metrics["score"] >= 0.5 + try: + metrics = run_eval(args) + assert metrics["score"] >= 0.6 + finally: + kill_child_process(process.pid) + + def test_chunked_prefill(self): + self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False) + + def test_mixed_chunked_prefill(self): + self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True) + + def test_chunked_prefill_without_radix_cache(self): + self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False) + + def test_mixed_chunked_prefill_without_radix_cache(self): + self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True) if __name__ == "__main__": diff --git a/test/srt/test_embedding_openai_server.py b/test/srt/test_embedding_openai_server.py index 45580feda0..45f7850da9 100644 --- a/test/srt/test_embedding_openai_server.py +++ b/test/srt/test_embedding_openai_server.py @@ -4,18 +4,24 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestOpenAIServer(unittest.TestCase): - @classmethod def setUpClass(cls): cls.model = "intfloat/e5-mistral-7b-instruct" cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, api_key=cls.api_key + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, ) cls.base_url += "/v1" cls.tokenizer = get_tokenizer(cls.model) diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py new file mode 100644 index 0000000000..3729ad26b6 --- /dev/null +++ b/test/srt/test_eval_accuracy_large.py @@ -0,0 +1,68 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEvalAccuracyLarge(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + 
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--log-level-http", "warning"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=3000, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.705, f"{metrics}" + + def test_human_eval(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="humaneval", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.64, f"{metrics}" + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.84, f"{metrics}" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_eval_accuracy_large_chunked_prefill.py b/test/srt/test_eval_accuracy_large_chunked_prefill.py new file mode 100644 index 0000000000..02df2a7f56 --- /dev/null +++ b/test/srt/test_eval_accuracy_large_chunked_prefill.py @@ -0,0 +1,68 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--log-level-http", "warning", "--chunked-prefill-size", "256"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=3000, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.705, f"{metrics}" + + def test_human_eval(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="humaneval", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.64, f"{metrics}" + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.84, f"{metrics}" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py new file mode 100644 index 0000000000..8ba71e5c83 --- /dev/null +++ b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py @@ -0,0 +1,74 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = 
popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--log-level-http", + "warning", + "--chunked-prefill-size", + "256", + "--enable-mixed-chunk", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=3000, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.705, f"{metrics}" + + def test_human_eval(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="humaneval", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.64, f"{metrics}" + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.84, f"{metrics}" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy_mini.py similarity index 70% rename from test/srt/test_eval_accuracy.py rename to test/srt/test_eval_accuracy_mini.py index a3f16f857e..25aa0ca116 100644 --- a/test/srt/test_eval_accuracy.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -5,18 +5,20 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) -class TestAccuracy(unittest.TestCase): - +class TestEvalAccuracyMini(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) @classmethod def tearDownClass(cls): @@ -27,12 +29,12 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=20, - num_threads=20, + num_examples=32, + num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.5 + assert metrics["score"] >= 0.6 if __name__ == "__main__": diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py new file mode 100644 index 0000000000..5393ecc33c --- /dev/null +++ b/test/srt/test_json_constrained.py @@ -0,0 +1,96 @@ +import json +import unittest + +import openai +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestJSONConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=300, api_key=cls.api_key + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): + headers = {"Authorization": f"Bearer {self.api_key}"} + response = requests.post( + self.base_url + "/generate", + json={ + "text": 
"The capital of France is", + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "stop_token_ids": [119690], + "json_schema": self.json_schema, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + headers=headers, + ) + print(json.dumps(response.json())) + print("=" * 100) + try: + js_obj = json.loads(response.json()["text"]) + except (TypeError, json.decoder.JSONDecodeError): + raise + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) + + def test_json_generate(self): + self.run_decode() + + def test_json_openai(self): + client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "Introduce the capital of France."}, + ], + temperature=0, + max_tokens=128, + extra_body={"json_schema": self.json_schema}, + ) + text = response.choices[0].message.content + + try: + js_obj = json.loads(text) + except (TypeError, json.decoder.JSONDecodeError): + print("JSONDecodeError", text) + raise + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py index 58f82b3516..10b82706a6 100644 --- a/test/srt/test_large_max_new_tokens.py +++ b/test/srt/test_large_max_new_tokens.py @@ -10,13 +10,13 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) class TestOpenAIServer(unittest.TestCase): - @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -25,7 +25,7 @@ def setUpClass(cls): cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, other_args=("--max-total-token", "1024"), env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ}, diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py new file mode 100644 index 0000000000..4f6e8db82c --- /dev/null +++ b/test/srt/test_moe_serving_throughput.py @@ -0,0 +1,105 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.bench_serving import run_benchmark +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import ( + DEFAULT_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestServingThroughput(unittest.TestCase): + def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size): + # Launch the server + other_args = [] + if disable_radix_cache: + other_args.append("--disable-radix-cache") + if disable_flashinfer: + other_args.append("--disable-flashinfer") + other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) + other_args.extend(["--tensor-parallel-size", "2"]) + other_args.append("--enable-p2p-check") + + model = DEFAULT_MOE_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + # Run benchmark + num_prompts = 200 + args = 
SimpleNamespace( + backend="sglang", + base_url=base_url, + host=None, + port=None, + dataset_name="random", + dataset_path="", + model=None, + tokenizer=None, + num_prompts=num_prompts, + sharegpt_output_len=None, + random_input_len=4096, + random_output_len=2048, + random_range_ratio=0.0, + request_rate=float("inf"), + multi=None, + seed=0, + output_file=None, + disable_tqdm=False, + disable_stream=False, + disable_ignore_eos=False, + extra_request_body=None, + ) + + try: + res = run_benchmark(args) + finally: + kill_child_process(process.pid) + + assert res["completed"] == num_prompts + return res + + def test_default(self): + res = self.run_test( + disable_radix_cache=ServerArgs.disable_radix_cache, + disable_flashinfer=ServerArgs.disable_flashinfer, + chunked_prefill_size=ServerArgs.chunked_prefill_size, + ) + + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + # A100 (PCIE): 950, H100 (SMX): 1800 + assert res["output_throughput"] > 1750 + + def test_default_without_radix_cache(self): + res = self.run_test( + disable_radix_cache=True, + disable_flashinfer=ServerArgs.disable_flashinfer, + chunked_prefill_size=ServerArgs.chunked_prefill_size, + ) + + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + # A100 (PCIE): 950, H100 (SMX): 1900 + assert res["output_throughput"] > 1850 + + def test_all_cases(self): + for disable_radix_cache in [False, True]: + for disable_flashinfer in [False, True]: + for chunked_prefill_size in [-1, 2048]: + self.run_test( + disable_radix_cache=False, + disable_flashinfer=False, + chunked_prefill_size=-1, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index b66c35f01d..3fc5785517 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -8,20 +8,23 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) class TestOpenAIServer(unittest.TestCase): - @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, api_key=cls.api_key + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, ) cls.base_url += "/v1" cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST) @@ -71,13 +74,12 @@ def run_completion( assert isinstance(response.choices[0].logprobs.tokens[0], str) assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict) ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1]) + # FIXME: Sometimes, some top_logprobs are missing in the return value. 
The reason is that some out_put id maps to the same output token and duplicate in the map # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" + assert ret_num_top_logprobs > 0 - if echo: - assert response.choices[0].logprobs.token_logprobs[0] == None - else: - assert response.choices[0].logprobs.token_logprobs[0] != None + assert response.choices[0].logprobs.token_logprobs[0] != None assert response.id assert response.created @@ -87,13 +89,26 @@ def run_completion( assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - def run_completion_stream(self, echo, logprobs, token_input): + def run_completion_stream( + self, echo, logprobs, use_list_input, parallel_sample_num, token_input + ): client = openai.Client(api_key=self.api_key, base_url=self.base_url) prompt = "The capital of France is" if token_input: - prompt_arg = self.tokenizer.encode(prompt) + prompt_input = self.tokenizer.encode(prompt) + num_prompt_tokens = len(prompt_input) + else: + prompt_input = prompt + num_prompt_tokens = len(self.tokenizer.encode(prompt)) + + if use_list_input: + prompt_arg = [prompt_input, prompt_input] + num_choices = len(prompt_arg) + num_prompt_tokens *= 2 else: - prompt_arg = prompt + prompt_arg = prompt_input + num_choices = 1 + generator = client.completions.create( model=self.model, prompt=prompt_arg, @@ -103,9 +118,10 @@ def run_completion_stream(self, echo, logprobs, token_input): logprobs=logprobs, stream=True, stream_options={"include_usage": True}, + n=parallel_sample_num, ) - first = True + is_firsts = {} for response in generator: usage = response.usage if usage is not None: @@ -113,10 +129,14 @@ def run_completion_stream(self, echo, logprobs, token_input): assert usage.completion_tokens > 0 assert usage.total_tokens > 0 continue + + index = response.choices[0].index + is_first = is_firsts.get(index, True) + if logprobs: assert response.choices[0].logprobs assert isinstance(response.choices[0].logprobs.tokens[0], str) - if not (first and echo): + if not (is_first and echo): assert isinstance( response.choices[0].logprobs.top_logprobs[0], dict ) @@ -127,15 +147,20 @@ def run_completion_stream(self, echo, logprobs, token_input): # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" assert ret_num_top_logprobs > 0 - if first: + if is_first: if echo: assert response.choices[0].text.startswith( prompt - ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}" - first = False + ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}" + is_firsts[index] = False assert response.id assert response.created + for index in [i for i in range(parallel_sample_num * num_choices)]: + assert not is_firsts.get( + index, True + ), f"index {index} is not found in the response" + def run_chat_completion(self, logprobs, parallel_sample_num): client = openai.Client(api_key=self.api_key, base_url=self.base_url) response = client.chat.completions.create( @@ -174,7 +199,7 @@ def run_chat_completion(self, logprobs, parallel_sample_num): assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - def run_chat_completion_stream(self, logprobs): + def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): client = openai.Client(api_key=self.api_key, base_url=self.base_url) generator = client.chat.completions.create( model=self.model, @@ -187,9 +212,10 @@ def run_chat_completion_stream(self, logprobs): top_logprobs=logprobs, stream=True, 
stream_options={"include_usage": True}, + n=parallel_sample_num, ) - is_first = True + is_firsts = {} for response in generator: usage = response.usage if usage is not None: @@ -198,11 +224,12 @@ def run_chat_completion_stream(self, logprobs): assert usage.total_tokens > 0 continue + index = response.choices[0].index data = response.choices[0].delta - if is_first: - data.role == "assistant" - is_first = False + if is_firsts.get(index, True): + assert data.role == "assistant" + is_firsts[index] = False continue if logprobs: @@ -224,8 +251,12 @@ def run_chat_completion_stream(self, logprobs): assert response.id assert response.created - def run_batch(self, mode): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) + for index in [i for i in range(parallel_sample_num)]: + assert not is_firsts.get( + index, True + ), f"index {index} is not found in the response" + + def _create_batch(self, mode, client): if mode == "completion": input_file_path = "complete_input.jsonl" # write content to input file @@ -301,9 +332,11 @@ def run_batch(self, mode): }, }, ] + with open(input_file_path, "w") as file: for line in content: file.write(json.dumps(line) + "\n") + with open(input_file_path, "rb") as file: uploaded_file = client.files.create(file=file, purpose="batch") if mode == "completion": @@ -316,13 +349,22 @@ def run_batch(self, mode): endpoint=endpoint, completion_window=completion_window, ) + + return batch_job, content, uploaded_file + + def run_batch(self, mode): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client) + while batch_job.status not in ["completed", "failed", "cancelled"]: time.sleep(3) print( f"Batch job status: {batch_job.status}...trying again in 3 seconds..." ) batch_job = client.batches.retrieve(batch_job.id) - assert batch_job.status == "completed" + assert ( + batch_job.status == "completed" + ), f"Batch job status is not completed: {batch_job.status}" assert batch_job.request_counts.completed == len(content) assert batch_job.request_counts.failed == 0 assert batch_job.request_counts.total == len(content) @@ -336,6 +378,29 @@ def run_batch(self, mode): if line.strip() != "" ] assert len(results) == len(content) + for delete_fid in [uploaded_file.id, result_file_id]: + del_pesponse = client.files.delete(delete_fid) + assert del_pesponse.deleted + + def run_cancel_batch(self, mode): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client) + + assert batch_job.status not in ["cancelling", "cancelled"] + + batch_job = client.batches.cancel(batch_id=batch_job.id) + assert batch_job.status == "cancelling" + + while batch_job.status not in ["failed", "cancelled"]: + batch_job = client.batches.retrieve(batch_job.id) + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." 
+ ) + time.sleep(3) + + assert batch_job.status == "cancelled" + del_response = client.files.delete(uploaded_file.id) + assert del_response.deleted def test_completion(self): for echo in [False, True]: @@ -355,8 +420,16 @@ def test_completion_stream(self): # parallel sampling adn list input are not supported in streaming mode for echo in [False, True]: for logprobs in [None, 5]: - for token_input in [False, True]: - self.run_completion_stream(echo, logprobs, token_input) + for use_list_input in [True, False]: + for parallel_sample_num in [1, 2]: + for token_input in [False, True]: + self.run_completion_stream( + echo, + logprobs, + use_list_input, + parallel_sample_num, + token_input, + ) def test_chat_completion(self): for logprobs in [None, 5]: @@ -365,12 +438,17 @@ def test_chat_completion(self): def test_chat_completion_stream(self): for logprobs in [None, 5]: - self.run_chat_completion_stream(logprobs) + for parallel_sample_num in [1, 2]: + self.run_chat_completion_stream(logprobs, parallel_sample_num) def test_batch(self): for mode in ["completion", "chat"]: self.run_batch(mode) + def test_calcel_batch(self): + for mode in ["completion", "chat"]: + self.run_cancel_batch(mode) + def test_regex(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index 808bc833ea..f1089a6a7b 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -1,13 +1,19 @@ +import os import unittest from types import SimpleNamespace from sglang.bench_serving import run_benchmark +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestServingThroughput(unittest.TestCase): - def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size): # Launch the server other_args = [] @@ -18,9 +24,12 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) model = DEFAULT_MODEL_NAME_FOR_TEST - base_url = "http://127.0.0.1:9157" + base_url = DEFAULT_URL_FOR_TEST process = popen_launch_server( - model, base_url, timeout=300, other_args=other_args + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, ) # Run benchmark @@ -55,28 +64,41 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size kill_child_process(process.pid) assert res["completed"] == num_prompts + return res def test_default(self): - self.run_test( - disable_radix_cache=False, - disable_flashinfer=False, - chunked_prefill_size=-1, + res = self.run_test( + disable_radix_cache=ServerArgs.disable_radix_cache, + disable_flashinfer=ServerArgs.disable_flashinfer, + chunked_prefill_size=ServerArgs.chunked_prefill_size, ) + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + # A100 (PCIE): 1450, H100 (SMX): 2550 + assert res["output_throughput"] > 2500 + def test_default_without_radix_cache(self): - self.run_test( + res = self.run_test( disable_radix_cache=True, - disable_flashinfer=False, - chunked_prefill_size=-1, + disable_flashinfer=ServerArgs.disable_flashinfer, + chunked_prefill_size=ServerArgs.chunked_prefill_size, ) - def 
test_default_without_flashinfer(self): - self.run_test( - disable_radix_cache=False, - disable_flashinfer=True, + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + # A100 (PCIE): 1500, H100 (SMX): 2850 + assert res["output_throughput"] > 2800 + + def test_default_without_chunked_prefill(self): + res = self.run_test( + disable_radix_cache=ServerArgs.disable_radix_cache, + disable_flashinfer=ServerArgs.disable_flashinfer, chunked_prefill_size=-1, ) + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + # A100 (PCIE): 1450, H100 (SMX): 2550 + assert res["output_throughput"] > 2500 + def test_all_cases(self): for disable_radix_cache in [False, True]: for disable_flashinfer in [False, True]: diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index 01bfdb96a3..b159bb5578 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -6,19 +6,22 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) class TestSkipTokenizerInit(unittest.TestCase): - @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"] + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--skip-tokenizer-init"], ) @classmethod diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 2c40f53602..818aae2151 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -6,25 +6,32 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) class TestSRTEndpoint(unittest.TestCase): - @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) @classmethod def tearDownClass(cls): kill_child_process(cls.process.pid) def run_decode( - self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1 + self, + return_logprob=False, + top_logprobs_num=0, + return_text=False, + n=1, + stream=False, ): response = requests.post( self.base_url + "/generate", @@ -35,14 +42,21 @@ def run_decode( "max_new_tokens": 32, "n": n, }, - "stream": False, + "stream": stream, "return_logprob": return_logprob, "top_logprobs_num": top_logprobs_num, "return_text_in_logprobs": return_text, "logprob_start_len": 0, }, ) - print(json.dumps(response.json())) + if not stream: + response_json = response.json() + else: + response_json = [] + for line in response.iter_lines(): + if line.startswith(b"data: ") and line[6:] != b"[DONE]": + response_json.append(json.loads(line[6:])) + print(json.dumps(response_json)) print("=" * 100) def test_simple_decode(self): @@ -51,6 +65,9 @@ def test_simple_decode(self): def test_parallel_sample(self): self.run_decode(n=3) + def test_parallel_sample_stream(self): + self.run_decode(n=3, stream=True) + def test_logprob(self): for top_logprobs_num in [0, 3]: for return_text in [True, False]: diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index c8869a9cca..e8cafa15d2 100644 
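Note on the streaming case added to `run_decode` in test_srt_endpoint.py above: the `data: ` / `[DONE]` parsing it performs mirrors how a client consumes the native `/generate` endpoint when `stream` is enabled. The standalone sketch below is illustrative only; the base URL and sampling parameters are placeholders, not values taken from this patch.

```python
import json

import requests

BASE_URL = "http://127.0.0.1:30000"  # placeholder; the tests use DEFAULT_URL_FOR_TEST

# Request streaming generation from the native /generate endpoint.
response = requests.post(
    BASE_URL + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
)

# Each server-sent event is prefixed with b"data: "; the stream ends with a "[DONE]" sentinel.
for line in response.iter_lines():
    if line.startswith(b"data: ") and line[6:] != b"[DONE]":
        chunk = json.loads(line[6:])
        print(chunk["text"])
```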
--- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -1,23 +1,28 @@ import unittest from types import SimpleNamespace +import requests + from sglang.srt.utils import kill_child_process from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) -class TestAccuracy(unittest.TestCase): - +class TestTorchCompile(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"] + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--enable-torch-compile", "--disable-radix-cache"], ) @classmethod @@ -29,12 +34,39 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=20, - num_threads=20, + num_examples=32, + num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.5 + assert metrics["score"] >= 0.6 + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + import time + + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 152 if __name__ == "__main__": diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py new file mode 100644 index 0000000000..a94ca92124 --- /dev/null +++ b/test/srt/test_triton_attn_backend.py @@ -0,0 +1,44 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestTritonAttnBackend(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--disable-flashinfer"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=32, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.6 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py new file mode 100644 index 0000000000..7b8404c735 --- /dev/null +++ b/test/srt/test_update_weights.py @@ -0,0 +1,109 @@ +import json +import unittest + +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestReplaceWeights(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = 
popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode(self): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "n": 1, + }, + "stream": False, + "return_logprob": False, + "top_logprobs_num": 0, + "return_text_in_logprobs": False, + "logprob_start_len": 0, + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + # return the "text" in response + text = response.json()["text"] + return text + + def get_model_info(self): + response = requests.get(self.base_url + "/get_model_info") + model_path = response.json()["model_path"] + print(json.dumps(response.json())) + return model_path + + def run_update_weights(self, model_path): + response = requests.post( + self.base_url + "/update_weights", + json={ + "model_path": model_path, + }, + ) + print(json.dumps(response.json())) + + def test_replace_weights(self): + origin_model_path = self.get_model_info() + print(f"origin_model_path: {origin_model_path}") + origin_response = self.run_decode() + + # update weights + new_model_path = "meta-llama/Meta-Llama-3.1-8B" + self.run_update_weights(new_model_path) + + updated_model_path = self.get_model_info() + print(f"updated_model_path: {updated_model_path}") + assert updated_model_path == new_model_path + assert updated_model_path != origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] != updated_response[:32] + + # update weights back + self.run_update_weights(origin_model_path) + updated_model_path = self.get_model_info() + assert updated_model_path == origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] == updated_response[:32] + + def test_replace_weights_unexist_model(self): + origin_model_path = self.get_model_info() + print(f"origin_model_path: {origin_model_path}") + origin_response = self.run_decode() + + # update weights + new_model_path = "meta-llama/Meta-Llama-3.1-8B-1" + self.run_update_weights(new_model_path) + + updated_model_path = self.get_model_info() + print(f"updated_model_path: {updated_model_path}") + assert updated_model_path == origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] == updated_response[:32] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 0449e33f1b..cf29c0e815 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -1,31 +1,38 @@ +import base64 +import io import json +import os import unittest +import numpy as np import openai +import requests +from decord import VideoReader, cpu +from PIL import Image -from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestOpenAIVisionServer(unittest.TestCase): - @classmethod def setUpClass(cls): - cls.model = "liuhaotian/llava-v1.6-vicuna-7b" + cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov" cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + 
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, other_args=[ "--chat-template", - "vicuna_v1.1", - "--tokenizer-path", - "llava-hf/llava-1.5-7b-hf", - "--log-requests", + "chatml-llava", + # "--log-requests", ], ) cls.base_url += "/v1" @@ -62,13 +69,130 @@ def test_chat_completion(self): assert response.choices[0].message.role == "assistant" text = response.choices[0].message.content assert isinstance(text, str) - assert "car" in text or "taxi" in text, text + assert "man" in text or "cab" in text, text + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def test_mult_images_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "I have two very different images. They are not related at all. " + "Please describe the first image in one sentence, and then describe the second image in another sentence.", + }, + ], + }, + ], + temperature=0, + ) + + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + print(text) + assert "man" in text and "taxi" in text, text + assert "logo" in text, text assert response.id assert response.created assert response.usage.prompt_tokens > 0 assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + def prepare_video_messages(self, video_path): + max_frames_num = 32 + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, max_frames_num, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + pil_img = Image.fromarray(frame) + buff = io.BytesIO() + pil_img.save(buff, format="JPEG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + base64_frames.append(base64_str) + + messages = [{"role": "user", "content": []}] + frame_format = { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{}"}, + } + + for base64_frame in base64_frames: + frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( + base64_frame + ) + messages[0]["content"].append(frame_format.copy()) + + prompt = {"type": "text", "text": "Please describe the video in detail."} + messages[0]["content"].append(prompt) + + return messages + + def test_video_chat_completion(self): + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + + if not os.path.exists(file_path): + response = requests.get(url) + response.raise_for_status() + + with open(file_path, "wb") as f: + f.write(response.content) + + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + messages = self.prepare_video_messages(file_path) + + video_request = client.chat.completions.create( + model="default", + 
messages=messages, + temperature=0, + max_tokens=1024, + stream=True, + ) + + print("-" * 30) + video_response = "" + for chunk in video_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + video_response += content + print(content, end="", flush=True) + print("-" * 30) + + # Add assertions to validate the video response + self.assertIsNotNone(video_response) + self.assertGreater(len(video_response), 0) + def test_regex(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url)
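Beyond the URL-based images exercised above, the `data:image/...;base64,` convention used in `prepare_video_messages` also works for a single local image through the OpenAI-compatible endpoint. A minimal sketch, assuming a server is already launched with the API key used in these tests, a placeholder base URL, and a hypothetical local file `example_image.png`:

```python
import base64

import openai

# Placeholder URL and the test API key; adjust to match the running server.
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

# Hypothetical local file, encoded the same way the video frames are above.
with open("example_image.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```

The same request shape extends to multiple images or sampled video frames; the tests above simply append additional `image_url` entries ahead of the text prompt.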