diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 75ad094fa138..b39dce2659a5 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 200 +MAX_SIZE_MB = 250 def print_top_10_largest_files(zip_file): diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml new file mode 100644 index 000000000000..c457468902c9 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 +model_name: "HandH1998/QQQ-Llama-3-8b-g128" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.409 + - name: "exact_match,flexible-extract" + value: 0.406 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml new file mode 100644 index 000000000000..a0466748ea71 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1 +model_name: "nvidia/Minitron-4B-Base" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.252 + - name: "exact_match,flexible-extract" + value: 0.252 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml new file mode 100644 index 000000000000..42936fbfbe7d --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1 +model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.578 + - name: "exact_match,flexible-extract" + value: 0.585 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 1d1b0ed38671..bca89f00653e 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -4,4 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +Minitron-4B-Base.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml +Qwen2-1.5B-Instruct-FP8W8.yaml +Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index c84e15093430..c1aebaf5b3bb 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -3,30 +3,51 @@ ## Introduction -This directory contains the performance benchmarking CI for vllm. -The goal is to help developers know the impact of their PRs on the performance of vllm. +This directory contains two sets of benchmark for vllm. 
+- Performance benchmark: benchmark vllm's performance under various workloads, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance +- Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm. -This benchmark will be *triggered* upon: + +See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for the latest nightly benchmark results. + + +## Performance benchmark quick overview + +**Benchmarking Coverage**: latency, throughput and fixed-QPS serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models. + +**Benchmarking Duration**: about 1hr. + +**For benchmarking developers**: please try your best to constrain the duration of benchmarking to about 1 hr so that it won't take forever to run. + + +## Nightly benchmark quick overview + +**Benchmarking Coverage**: Fixed-QPS serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. + +**Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy. + +**Benchmarking Duration**: about 3.5hrs. + + + +## Trigger the benchmark + +The performance benchmark will be triggered when: - A PR being merged into vllm. - Every commit for those PRs with `perf-benchmarks` label. -**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for more GPUs is comming later), with different models. +The nightly benchmark will be triggered when: +- Every commit for those PRs with `nightly-benchmarks` label. -**Benchmarking Duration**: about 1hr. -**For benchmarking developers**: please try your best to constraint the duration of benchmarking to less than 1.5 hr so that it won't take forever to run. -## Configuring the workload +## Performance benchmark details -The benchmarking workload contains three parts: -- Latency tests in `latency-tests.json`. -- Throughput tests in `throughput-tests.json`. -- Serving tests in `serving-tests.json`. +See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. -See [descriptions.md](tests/descriptions.md) for detailed descriptions. -### Latency test +#### Latency test Here is an example of one test inside `latency-tests.json`: ``` @@ -54,12 +75,12 @@ Note that the performance numbers are highly sensitive to the value of the param WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file. -### Throughput test +#### Throughput test The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. -### Serving test +#### Serving test We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead.
The corresponding parameters are in `serving-tests.json`, and here is an example: ``` @@ -96,9 +117,36 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. -## Visualizing the results +#### Visualizing the results The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. + + + +## Nightly test details + +See [nightly-descriptions.md](nightly-descriptions.md) for a detailed description of the test workload, models and docker containers used to benchmark other LLM engines. + + +#### Workflow + +- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. +- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which probes the serving engine of the current container. +- `run-nightly-suite.sh` then invokes `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark. +- Finally, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and upload the results to buildkite. + +#### Nightly tests + +In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is very similar to that of the performance benchmark. + +#### Docker containers + +The docker containers for benchmarking are specified in `nightly-pipeline.yaml`. + +WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`. + +WARNING: bumping `trt-llm` to the latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).
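As a quick illustration of how these JSON test files drive the benchmark scripts, here is a minimal Python sketch of the transformation that the `json2args` helper in `run-benchmarks-suite.sh` performs in bash: parameter names have `_` replaced by `-` and become CLI flags. This is a simplified sketch, not the exact bash implementation, and the test case below is hypothetical, shown only to illustrate the shape of an entry.

```python
# Illustrative sketch only: mirrors what the json2args helper in
# run-benchmarks-suite.sh does (JSON parameters -> CLI flags, '_' -> '-').
import json


def params_to_cli_args(params: dict) -> str:
    """Convert a test-case parameter dict into benchmark-script flags."""
    args = []
    for key, value in params.items():
        flag = "--" + key.replace("_", "-")
        # Empty values (e.g. "disable_log_requests": "") are treated as bare flags here.
        args.append(flag if value == "" else f"{flag} {value}")
    return " ".join(args)


# A hypothetical entry in the style of tests/latency-tests.json.
test_case = json.loads("""
{
    "test_name": "latency_llama8B_tp1",
    "parameters": {
        "model": "meta-llama/Meta-Llama-3-8B",
        "tensor_parallel_size": 1,
        "load_format": "dummy",
        "num_iters_warmup": 5,
        "num_iters": 15
    }
}
""")

print(params_to_cli_args(test_case["parameters"]))
# --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy \
# --num-iters-warmup 5 --num-iters 15
```

The resulting flag string is what gets appended to `benchmark_latency.py`, `benchmark_throughput.py`, or `benchmark_serving.py`, which is why the README above warns against putting results-saving flags such as `--output-json` or `--save-results` into these JSON files.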
+ diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 02c0ee534d72..8490c9f1da22 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -42,20 +42,20 @@ steps: - name: devshm emptyDir: medium: Memory - - label: "H100" - agents: - queue: H100 - plugins: - - docker#v5.11.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - command: - - bash - - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - mount-buildkite-agent: true - propagate-environment: true - ipc: host - gpus: all - environment: - - VLLM_USAGE_SOURCE - - HF_TOKEN + # - label: "H100" + # agents: + # queue: H100 + # plugins: + # - docker#v5.11.0: + # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + # command: + # - bash + # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + # mount-buildkite-agent: true + # propagate-environment: true + # ipc: host + # gpus: all + # environment: + # - VLLM_USAGE_SOURCE + # - HF_TOKEN diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh index 04b02adf3644..1a88d038b4b5 100644 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh @@ -34,6 +34,15 @@ check_hf_token() { fi } +ensure_sharegpt_downloaded() { + local FILE=ShareGPT_V3_unfiltered_cleaned_split.json + if [ ! -f "$FILE" ]; then + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE + else + echo "$FILE already exists." + fi +} + json2args() { # transforms the JSON string to command line args, and '_' is replaced to '-' # example: @@ -73,11 +82,6 @@ kill_gpu_processes() { echo "All GPU processes have been killed." 
fi - # Sometimes kill with pid doesn't work properly, we can also kill all process running python or python3 - # since we are in container anyway - pkill -9 -f python - pkill -9 -f python3 - # waiting for GPU processes to be fully killed # loop while nvidia-smi returns any processes while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do @@ -355,7 +359,7 @@ main() { # prepare for benchmarking cd benchmarks || exit 1 - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + ensure_sharegpt_downloaded declare -g RESULTS_FOLDER=results/ mkdir -p $RESULTS_FOLDER QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json index 86a0fefa339f..300af0524d7c 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json @@ -55,5 +55,26 @@ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 200 } + }, + { + "test_name": "serving_llama70B_tp4_sharegpt_specdecode", + "qps_list": [2], + "server_parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "disable_log_requests": "", + "tensor_parallel_size": 4, + "swap_space": 16, + "speculative_model": "turboderp/Qwama-0.5B-Instruct", + "num_speculative_tokens": 4, + "speculative_draft_tensor_parallel_size": 1, + "use_v2_block_manager": "" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } } -] \ No newline at end of file +] diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 4fa1951134eb..5be9a553dddd 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -3,13 +3,15 @@ steps: agents: queue: cpu_queue commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ." 
- "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" # rename the files to change linux -> manylinux1 - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done" - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/" - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/" + env: + DOCKER_BUILDKIT: "1" matrix: setup: cuda_version: diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 618d712b0279..ccc2f090565e 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -55,7 +55,7 @@ while true; do done echo "--- Pulling container" -image_name="rocmshared/vllm-ci:${BUILDKITE_COMMIT}" +image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" docker pull ${image_name} diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index a7678aae5464..45bc8eb2f847 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -3,26 +3,38 @@ set -ex # Try building the docker image -docker build -t cpu-test -f Dockerfile.cpu . -docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . +numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu . +numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . # Setup cleanup remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image +# Run the image, setting --shm-size=4g for tensor parallel. docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ - --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test + --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ - --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2 + --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2 # offline inference -docker exec cpu-test bash -c "python3 examples/offline_inference.py" docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test -docker exec cpu-test bash -c "cd tests; +docker exec cpu-test bash -c " pip install pytest Pillow protobuf - cd ../ - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + +# online inference +docker exec cpu-test bash -c " + export VLLM_CPU_KVCACHE_SPACE=10 + export VLLM_CPU_OMP_THREADS_BIND=48-92 + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & + timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 + python3 benchmarks/benchmark_serving.py \ + --backend 
vllm \ + --dataset-name random \ + --model facebook/opt-125m \ + --num-prompts 20 \ + --endpoint /v1/completions \ + --tokenizer facebook/opt-125m" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e7dd1fdb2e66..93b3e3fe9166 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -17,11 +17,10 @@ steps: - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker -- label: Tensorizer, Metrics, Tracing Test +- label: Metrics, Tracing Test fast_check: true fast_check_only: true commands: - - apt-get install -y curl libsodium23 && pytest -v -s tensorizer_loader # Tensorizer - pytest -v -s metrics # Metrics - "pip install \ opentelemetry-sdk \ @@ -45,7 +44,7 @@ steps: fast_check: true commands: # This flashinfer installation will fail on AMD ROCm, so it is set as optional. - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py @@ -57,7 +56,6 @@ steps: fast_check: true commands: - pytest -v -s core - - pytest -v -s distributed/test_parallel_state.py - label: Distributed Comm Ops Test #mirror_hardwares: [amd] @@ -84,20 +82,9 @@ steps: num_gpus: 2 commands: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - 
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py @@ -109,11 +96,6 @@ steps: fast_check: true commands: - pytest -v -s distributed/test_pynccl.py - # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. - # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - label: Pipeline Parallelism Test @@ -141,14 +123,13 @@ steps: working_dir: "/vllm-workspace/examples" mirror_hardwares: [amd] commands: - # install aws cli for llava_example.py # install tensorizer for tensorize_vllm_model.py - pip install awscli tensorizer - python3 offline_inference.py - python3 cpu_offload.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - - python3 llava_example.py + - python3 offline_inference_vision_language.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Inputs Test @@ -157,17 +138,17 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s multimodal -- label: Kernels Test %N - #mirror_hardwares: [amd] - commands: - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT - parallelism: 4 +# - label: Kernels Test %N +# #mirror_hardwares: [amd] +# commands: +# - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl +# - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT +# parallelism: 4 - label: Models Test #mirror_hardwares: [amd] commands: - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - pytest -v -s models -m \"not vlm\" - label: Vision Language Models Test @@ -204,23 +185,24 @@ steps: - export VLLM_ATTENTION_BACKEND=XFORMERS - pytest -v -s spec_decode -- label: LoRA Test %N - #mirror_hardwares: [amd] - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT 
--ignore=lora/test_long_context.py - parallelism: 4 - -- label: LoRA Long Context (Distributed) - #mirror_hardwares: [amd] - num_gpus: 4 - # This test runs llama 13B, so it is required to run on 4 GPUs. - commands: - # FIXIT: find out which code initialize cuda before running the test - # before the fix, we need to use spawn to test it - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s -x lora/test_long_context.py +# - label: LoRA Test %N +# #mirror_hardwares: [amd] +# command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py +# parallelism: 4 + +# - label: LoRA Long Context (Distributed) +# #mirror_hardwares: [amd] +# num_gpus: 4 +# # This test runs llama 13B, so it is required to run on 4 GPUs. +# commands: +# # FIXIT: find out which code initialize cuda before running the test +# # before the fix, we need to use spawn to test it +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - pytest -v -s -x lora/test_long_context.py - label: Tensorizer Test #mirror_hardwares: [amd] + fast_check: true commands: - apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn @@ -281,9 +263,6 @@ steps: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl + - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s -x lora/test_mixtral.py diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index e9b6e28fa6bc..79b85d8cad0d 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -30,12 +30,6 @@ jobs: run: | EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' ) find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 5780f09a646c..8d423657630c 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -32,22 +32,17 @@ jobs: pip install 
types-setuptools - name: Mypy run: | - mypy tests --config-file pyproject.toml - mypy vllm/*.py --config-file pyproject.toml - mypy vllm/attention --config-file pyproject.toml - mypy vllm/core --config-file pyproject.toml - mypy vllm/distributed --config-file pyproject.toml - mypy vllm/engine --config-file pyproject.toml - mypy vllm/entrypoints --config-file pyproject.toml - mypy vllm/executor --config-file pyproject.toml - mypy vllm/inputs --config-file pyproject.toml - mypy vllm/logging --config-file pyproject.toml - mypy vllm/lora --config-file pyproject.toml - mypy vllm/model_executor --config-file pyproject.toml - mypy vllm/multimodal --config-file pyproject.toml - mypy vllm/platforms --config-file pyproject.toml - mypy vllm/spec_decode --config-file pyproject.toml - mypy vllm/transformers_utils --config-file pyproject.toml - mypy vllm/usage --config-file pyproject.toml - mypy vllm/worker --config-file pyproject.toml + mypy + mypy tests --follow-imports skip + mypy vllm/attention --follow-imports skip + mypy vllm/core --follow-imports skip + mypy vllm/distributed --follow-imports skip + mypy vllm/engine --follow-imports skip + mypy vllm/entrypoints --follow-imports skip + mypy vllm/executor --follow-imports skip + mypy vllm/lora --follow-imports skip + mypy vllm/model_executor --follow-imports skip + mypy vllm/prompt_adapter --follow-imports skip + mypy vllm/spec_decode --follow-imports skip + mypy vllm/worker --follow-imports skip diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 15c2ec05b25d..aeeaf6efab04 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -48,8 +48,8 @@ jobs: fail-fast: false matrix: os: ['ubuntu-20.04'] - python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.3.1'] # Must be the most recent version that meets requirements-cuda.txt. + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. 
cuda-version: ['11.8', '12.1'] steps: diff --git a/.github/workflows/remove_label_not_ready_comment.yml b/.github/workflows/remove_label_not_ready_comment.yml new file mode 100644 index 000000000000..d1da7726eaee --- /dev/null +++ b/.github/workflows/remove_label_not_ready_comment.yml @@ -0,0 +1,23 @@ +name: Remove ready Label on notready Comment + +on: + issue_comment: + types: [created] + +jobs: + add-ready-label: + runs-on: ubuntu-latest + if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready') + steps: + - name: Remove ready label + uses: actions/github-script@v5 + with: + script: | + github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + name: 'ready' + }) + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 773def58fd96..1a794af572fe 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 60a3978f9abd..0a759d303238 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 -# Make sure punica is built for the release (for LoRA) -export VLLM_INSTALL_PUNICA_KERNELS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" # Build diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index 04f307bcf8b0..c89f82dfaaaf 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 428e19908858..f1959ad2743f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -10,6 +10,7 @@ build: sphinx: configuration: docs/source/conf.py + fail_on_warning: true # If using Sphinx, optionally build your docs in additional formats such as PDF formats: diff --git a/CMakeLists.txt b/CMakeLists.txt index 270a8568b2f0..835014c97de2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") +set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") @@ -32,7 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") +set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") # @@ -66,6 +66,39 @@ endif() # find_package(Torch REQUIRED) +# +# Add the `default` target which detects which extensions should be +# built based on platform/architecture. This is the same logic that +# setup.py uses to select which extensions should be built and should +# be kept in sync. +# +# The `default` target makes direct use of cmake easier since knowledge +# of which extensions are supported has been factored in, e.g. +# +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. +# cmake --build . --target default +# +add_custom_target(default) +message(STATUS "Enabling core extension.") + +# # Define _core_C extension +# # built for (almost) every target platform, (excludes TPU and Neuron) + +# set(VLLM_EXT_SRC +# "csrc/core/torch_bindings.cpp") + +# define_gpu_extension_target( +# _core_C +# DESTINATION vllm +# LANGUAGE CXX +# SOURCES ${VLLM_EXT_SRC} +# COMPILE_FLAGS ${CXX_COMPILE_FLAGS} +# USE_SABI 3 +# WITH_SOABI) + +# add_dependencies(default _core_C) + # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -78,7 +111,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND set(VLLM_GPU_LANG "SYCL") include(${CMAKE_CURRENT_LIST_DIR}/cmake/xpu_extension.cmake) else() - message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + return() endif() return() endif() @@ -136,7 +169,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") endif() # -# Define extension targets +# Define other extension targets # # @@ -160,12 +193,13 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) - SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.0 - GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + # CUTLASS 3.5.1 + GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 + GIT_PROGRESS TRUE ) FetchContent_MakeAvailable(cutlass) @@ -174,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" @@ -204,7 +239,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -226,76 +261,7 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -# -# _punica_C extension -# - -set(VLLM_PUNICA_EXT_SRC - "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cu" - "csrc/punica/torch_bindings.cpp") - -# -# Copy GPU compilation flags+update for punica -# -set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS}) -list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS - 
"-D__CUDA_NO_HALF_OPERATORS__" - "-D__CUDA_NO_HALF_CONVERSIONS__" - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" - "-D__CUDA_NO_HALF2_OPERATORS__") - -# -# Filter out CUDA architectures < 8.0 for punica. -# -if (${VLLM_GPU_LANG} STREQUAL "CUDA") - set(VLLM_PUNICA_GPU_ARCHES) - foreach(ARCH ${VLLM_GPU_ARCHES}) - string_to_ver(CODE_VER ${ARCH}) - if (CODE_VER GREATER_EQUAL 8.0) - list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH}) - endif() - endforeach() - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -elseif(${VLLM_GPU_LANG} STREQUAL "HIP") - set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -endif() - -if (VLLM_PUNICA_GPU_ARCHES) - define_gpu_extension_target( - _punica_C - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_PUNICA_EXT_SRC} - COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} - ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} - USE_SABI 3 - WITH_SOABI) -else() - message(WARNING "Unable to create _punica_C target because none of the " - "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0") -endif() -# -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. -# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") @@ -304,12 +270,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) - # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or - # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and - # there are supported target arches. 
- if (VLLM_PUNICA_GPU_ARCHES AND - (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS)) - message(STATUS "Enabling punica extension.") - add_dependencies(default _punica_C) - endif() endif() diff --git a/Dockerfile b/Dockerfile index 2b4da1ce7ee1..49aaea2949ac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,7 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt @@ -78,6 +79,7 @@ COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml COPY vllm vllm @@ -88,8 +90,6 @@ ENV MAX_JOBS=${max_jobs} # number of threads used by nvcc ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads -# make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 ARG buildkite_commit ENV BUILDKITE_COMMIT=${buildkite_commit} @@ -103,7 +103,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && tar -xzf sccache.tar.gz \ && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ - && export SCCACHE_BUCKET=vllm-build-sccache \ + && if [ "$CUDA_VERSION" = "11.8.0" ]; then \ + export SCCACHE_BUCKET=vllm-build-sccache-2; \ + else \ + export SCCACHE_BUCKET=vllm-build-sccache; \ + fi \ && export SCCACHE_REGION=us-west-2 \ && export CMAKE_BUILD_TYPE=Release \ && sccache --show-stats \ @@ -168,7 +172,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && python3 --version RUN apt-get update -y \ - && apt-get install -y python3-pip git curl libibverbs-dev + && apt-get install -y python3-pip git vim curl libibverbs-dev # Install pip s.t. 
it will be compatible with our PYTHON_VERSION RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} @@ -190,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl #################### vLLM installation IMAGE #################### diff --git a/Dockerfile.cpu b/Dockerfile.cpu index f95d748f1e4b..78730f39721c 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -2,8 +2,8 @@ FROM ubuntu:22.04 AS cpu-test-1 -RUN apt-get update -y \ - && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ +RUN apt-get update -y \ + && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html @@ -13,8 +13,9 @@ RUN pip install intel-openmp ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD" +RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy diff --git a/Dockerfile.openvino b/Dockerfile.openvino index cfb786485266..c84dea419e58 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -1,7 +1,7 @@ # The vLLM Dockerfile is used to construct vLLM image that can be directly used # to run the OpenAI compatible server. 
-FROM ubuntu:20.04 AS dev +FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ apt-get install -y python3-pip git @@ -13,12 +13,15 @@ COPY requirements-common.txt /workspace/vllm/ COPY requirements-openvino.txt /workspace/vllm/ COPY vllm/ /workspace/vllm/vllm +COPY csrc/core /workspace/vllm/csrc/core +COPY cmake/utils.cmake /workspace/vllm/cmake/ +COPY CMakeLists.txt /workspace/vllm/ COPY setup.py /workspace/vllm/ # install build requirements RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt # build vLLM with OpenVINO backend -RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ COPY examples/ /workspace/vllm/examples COPY benchmarks/ /workspace/vllm/benchmarks diff --git a/Dockerfile.rocm b/Dockerfile.rocm index ff3979145639..33423fde4ff9 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -53,10 +53,10 @@ RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(whic # Install torch == 2.5.0 on ROCm RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ *"rocm-6.1"*) \ - python3 -m pip uninstall -y torch torchaudio torchvision \ + python3 -m pip uninstall -y torch torchvision \ && python3 -m pip install --no-cache-dir --pre \ - torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ - torchvision==0.20.0.dev20240710 \ + torch==2.5.0.dev20240726 \ + torchvision==0.20.0.dev20240726 \ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ *) ;; esac @@ -127,19 +127,11 @@ FROM base AS final # Import the vLLM development directory from the build context COPY . . -# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. -# Manually remove it so that later steps of numpy upgrade can continue -RUN case "$(which python3)" in \ - *"/opt/conda/envs/py_3.9"*) \ - rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ - *) ;; esac - # Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install --upgrade numba scipy huggingface-hub[cli] -# Make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 + # Workaround for ray >= 2.10.0 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 # Silences the HF Tokenizers warning diff --git a/Dockerfile.tpu b/Dockerfile.tpu index be7dbe63cb23..adebb8ab5adc 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20240713" +ARG NIGHTLY_DATE="20240726" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE @@ -12,6 +12,9 @@ RUN pip install "numpy<2" RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +# Fix FastAPI dependence +RUN pip install "starlette<0.38.0" + # Build vLLM. COPY . 
/workspace/vllm ENV VLLM_TARGET_DEVICE="tpu" diff --git a/MANIFEST.in b/MANIFEST.in index 82be639ef4d7..5a41e5e71418 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include LICENSE +include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/README.md b/README.md index 725c69b3f02a..5f23f0813f60 100644 --- a/README.md +++ b/README.md @@ -16,16 +16,9 @@ Easy, fast, and cheap LLM serving for everyone --- -**The Fifth vLLM Bay Area Meetup (July 24th 5pm-8pm PT)** - -We are excited to announce our fifth vLLM Meetup! -Join us to hear the vLLM's recent updates and the upcoming roadmap. -Additionally, our collaborators from AWS will be presenting their insights and experiences in deploying vLLM. -Register now [here](https://lu.ma/lp0gyjqr) and be part of the event! - ---- - *Latest News* 🔥 +- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). +- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). @@ -46,7 +39,7 @@ vLLM is fast with: - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache - Optimized CUDA kernels -**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/3924) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)). +**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)). 
vLLM is flexible and easy to use with: diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fbab547d094f..44c47617e2fe 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -236,6 +236,8 @@ async def async_request_openai_completions( "temperature": 0.0, "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, + "min_tokens": request_func_input.output_len, + "ignore_eos": True, "stream": True, } headers = { diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8d0554b0f4f0..97afd301c8f2 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptStrictInputs +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptStrictInputs] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fc0dbf77f16b..07c0e676dd9d 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -428,7 +428,9 @@ def main(args: argparse.Namespace): np.random.seed(args.seed) backend = args.backend - model_id = args.model + # model_id = args.model + model_id = args.model.split('/')[-1] + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model if args.base_url is not None: diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 234c2c8a1074..64011b2db239 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -13,7 +13,7 @@ from vllm import _custom_ops as ops from vllm.utils import FlexibleArgumentParser -DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] @@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) timers = [] - # pytorch impl + # pytorch impl - bfloat16 timers.append( bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, torch.bfloat16, label, sub_label, pytorch_mm_impl, "pytorch_bf16_bf16_bf16_matmul-no-scales")) + # pytorch impl - float16 + timers.append( + bench_fn(a.to(dtype=torch.float16, device="cuda"), + b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b, + torch.float16, label, sub_label, pytorch_mm_impl, + "pytorch_fp16_fp16_fp16_matmul-no-scales")) + # cutlass impl timers.append( bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 3da4cecd7eef..536c133bb334 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -7,16 +7,17 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - 
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) + MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, marlin_quantize) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, sort_weights) +from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] @@ -27,13 +28,14 @@ def bench_run(results: List[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, num_bits: int, group_size: int, - size_m: int, size_k: int, size_n: int): + act_order: bool, is_k_full: bool, quant_type: ScalarType, + group_size: int, size_m: int, size_k: int, size_n: int): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) + sub_label = ("{}, act={} k_full={}, q={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, + str(quant_type), group_size, size_m, + size_k, size_n)) print(f"Testing: {sub_label}") @@ -50,16 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_g_idx, marlin_sort_indices, marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, act_order) + ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) + + marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant (w_ref, q_w, s, g_idx, - rand_perm) = quantize_weights(b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" # so that group ids are increasing @@ -73,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) globals = { # Gen params - "num_bits": num_bits, + "quant_type": quant_type, "group_size": group_size, "size_m": size_m, "size_n": size_n, @@ -87,6 +92,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, "marlin_w_ref": marlin_w_ref, "marlin_q_w": marlin_q_w, "marlin_s": marlin_s, + "marlin_zp": marlin_zp, "marlin_g_idx": marlin_g_idx, "marlin_sort_indices": marlin_sort_indices, "marlin_rand_perm": marlin_rand_perm, @@ -125,19 +131,29 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, 
marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm_fp16", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="gptq_marlin_gemm", + description="gptq_marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time)) - if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): results.append( benchmark.Timer( stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -147,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -183,12 +199,13 @@ def main(args): ) > 0 and is_k_full not in args.limit_k_full: continue - for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: - if len(args.limit_num_bits - ) > 0 and num_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types( + False): + if len(args.limit_num_bits) > 0 and \ + quant_type.size_bits not in args.limit_num_bits: continue - for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + for group_size in MARLIN_SUPPORTED_GROUP_SIZES: if len( args.limit_group_size ) > 0 and group_size not in args.limit_group_size: @@ -202,8 +219,8 @@ def main(args): for size_m in args.batch_sizes: bench_run(results, model, act_order, is_k_full, - num_bits, group_size, size_m, size_k, - size_n) + quant_type, group_size, size_m, + size_k, size_n) compare = benchmark.Compare(results) compare.print() diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 78cac8a555d1..a04433142da4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -175,7 +175,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 192, 256], + choices=[64, 80, 96, 112, 120, 128, 192, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 78736c7a7ba6..f542684a9a2a 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -94,7 +94,7 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument("--num-heads", type=int, default=8) 
parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 192, 256], + choices=[64, 80, 96, 112, 120, 128, 192, 256], default=128) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 690559ee265e..3ba3a2b6a93c 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -83,6 +83,8 @@ endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") +list(APPEND LIBS "numa") + # # Define extension targets @@ -95,6 +97,7 @@ set(VLLM_EXT_SRC "csrc/cpu/activation.cpp" "csrc/cpu/attention.cpp" "csrc/cpu/cache.cpp" + "csrc/cpu/utils.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" "csrc/cpu/torch_bindings.cpp") @@ -104,11 +107,11 @@ define_gpu_extension_target( DESTINATION vllm LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} + LIBRARIES ${LIBS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} USE_SABI 3 WITH_SOABI ) -add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 0c60a86c960d..86ac11289e1b 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -181,7 +181,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) # # The torch cmake setup hardcodes the detected architecture flags in # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it - # can't modified on a per-target basis, e.g. for the `punica` extension. + # can't modified on a per-target basis. # So, all the `-gencode` flags need to be extracted and removed from # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. # Since it's not possible to use `target_compiler_options` for adding target diff --git a/cmake/xpu_extension.cmake b/cmake/xpu_extension.cmake index 1a5e6780bb2b..008751e26f5c 100644 --- a/cmake/xpu_extension.cmake +++ b/cmake/xpu_extension.cmake @@ -53,7 +53,7 @@ define_gpu_extension_target( WITH_SOABI ) -add_custom_target(default) +add_custom_target(default_xpu) message(STATUS "Enabling C extension.") -add_dependencies(default _C) +add_dependencies(default_xpu _C) diff --git a/collect_env.py b/collect_env.py index 083cb768f539..244e4ddd5aed 100644 --- a/collect_env.py +++ b/collect_env.py @@ -65,6 +65,7 @@ "optree", "nccl", "transformers", + "zmq", } DEFAULT_PIP_PATTERNS = { @@ -77,6 +78,7 @@ "onnx", "nccl", "transformers", + "zmq", } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 350dbce1d7ba..bcd170411e7c 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -706,7 +706,7 @@ void paged_attention_v1_launcher( int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. 
@@ -751,6 +751,9 @@ void paged_attention_v1_launcher( case 112: LAUNCH_PAGED_ATTENTION_V1(112); break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; @@ -862,7 +865,7 @@ void paged_attention_v2_launcher( int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. @@ -912,6 +915,9 @@ void paged_attention_v2_launcher( case 112: LAUNCH_PAGED_ATTENTION_V2(112); break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 3cdcb95e0809..97a25baa1fc0 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #else return __bfloat1622float2(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { @@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #else return __bfloat162bfloat162(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } // Vector addition. @@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { return __hadd(a, b); #endif #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { @@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hadd2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { @@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #else return __hmul(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hmul2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, #else return __hfma2(a, b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, @@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, #else return __hfma2(bf162bf162(a), b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { diff --git a/csrc/cache.h b/csrc/cache.h index 52177e8901a8..11c4c5001daa 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + const double k_scale, const double v_scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu 
b/csrc/cache_kernels.cu index caef7f5e1863..1be806bbfa43 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel( } } -template +template __global__ void reshape_and_cache_flash_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, // head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, // head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, - const int num_heads, const int head_size, const int block_size) { + const int num_heads, const int head_size, const int block_size, + const float k_scale, const float v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; - k_cache[tgt_value_idx] = key[src_key_idx]; - v_cache[tgt_value_idx] = value[src_value_idx]; + const int64_t tgt_key_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_value_idx] = tgt_key; + value_cache[tgt_key_value_idx] = tgt_value; + } else { + key_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_key, k_scale); + value_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_value, v_scale); + } } } } // namespace vllm @@ -278,40 +288,45 @@ void reshape_and_cache( CALL_RESHAPE_AND_CACHE) } +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. 
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, key_stride, \ + value_stride, num_heads, head_size, block_size, k_scale, v_scale); + void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) { - // FIXME: only support auto datatype, does not support fp8 - if (kv_cache_dtype != "auto") { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); - int block_size = k_cache.size(1); + int block_size = key_cache.size(1); int key_stride = key.stride(0); int value_stride = value.stride(0); - int block_stride = k_cache.stride(0); - TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); + int block_stride = key_cache.stride(0); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), "reshape_and_cache_flash", [&] { - vllm::reshape_and_cache_flash_kernel - <<>>( - key.data_ptr(), value.data_ptr(), - k_cache.data_ptr(), v_cache.data_ptr(), - slot_mapping.data_ptr(), block_stride, key_stride, - value_stride, num_heads, head_size, block_size); - }); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_FLASH); } namespace vllm { diff --git a/csrc/registration.h b/csrc/core/registration.h similarity index 100% rename from csrc/registration.h rename to csrc/core/registration.h diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 000000000000..9f78402eee2a --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,382 @@ +#pragma once + +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// ScalarTypeTorch is a subclass of ScalarType that is compatible with +// TORCH_LIBRARY, making it accessible from Python as well meaning this class +// can be used as a argument for custom operators, helping to simplify these +// interfaces. +// +// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// these type definitions should be kept up to date with any Python API changes +// here. 
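(Editorial note, not part of the patch: the header comment above describes the intent of the new `ScalarType` class declared just below and its `ScalarTypeTorch` wrapper for Python. As a quick orientation for reviewers, here is a minimal usage sketch of the constexpr constructors this header defines. It is illustrative only, assumes a PyTorch build environment because the header relies on `TORCH_CHECK`, and assumes the header is reachable as `"core/scalar_type.hpp"`.)

```cpp
// Minimal sketch of using the ScalarType class introduced below; not vLLM code.
#include <iostream>
#include "core/scalar_type.hpp"

int main() {
  // 4-bit unsigned integer with a bias of 8, i.e. the GPTQ-style "uint4b8".
  constexpr auto u4b8 = vllm::ScalarType::uint(4, /*bias=*/8);
  std::cout << u4b8.str() << "\n";        // prints "uint4b8"
  std::cout << u4b8.size_bits() << "\n";  // prints 4

  // An IEEE-754 compliant fp8-style type: 5 exponent bits, 2 mantissa bits.
  constexpr auto fe5m2 = vllm::ScalarType::float_IEEE754(5, 2);
  std::cout << fe5m2.is_floating_point() << " " << fe5m2.has_nans() << "\n";  // 1 1
  return 0;
}
```

These two constructions mirror the `kU4B8` and `kFE5M2` constants that this header defines further down, so the sketch only exercises behavior already present in the patch.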
+// +class ScalarType { + public: + enum NanRepr : int64_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa, + int64_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + bias(bias), + signed_(signed_), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) { + return ScalarType(true, 0, size_bits - 1, bias); + } + + static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) { + return ScalarType(false, 0, size_bits, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(int64_t exponent, + int64_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(true, exponent, mantissa, 0, finite_values_only, + nan_repr); + } + + int64_t const exponent; // size of the exponent field (0 for integer types) + int64_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + int64_t const bias; // stored values equal value + bias, + // used for quantized type + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + + // Extra Floating point info + bool const finite_values_only; // i.e. 
no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + int64_t size_bits() const { return mantissa + exponent + is_signed(); } + bool is_signed() const { return signed_; } + bool is_integer() const { return exponent == 0; } + bool is_floating_point() const { return exponent > 0; } + bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } + bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from +// torch::CustomClassHolder), we use multiple inheritance here since we cannot +// have ScalarType inherit from torch::CustomClassHolder and have a constexpr +// constructor at the same time (torch::CustomClassHolder does not have a +// constexpr destructor) +class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { + public: + ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, + bool _signed) + : ScalarType(exponent, mantissa, bias, _signed){}; + + ScalarTypeTorch(ScalarType type) : ScalarType(type){}; + + using Base = ScalarType; + using Self = ScalarTypeTorch; + using SelfPtr = c10::intrusive_ptr; + + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::int_(size_bits, bias.value_or(0))); + } + + static SelfPtr uint(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::uint(size_bits, bias.value_or(0))); + } + + static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + return c10::make_intrusive( + ScalarType::float_IEEE754(exponent, mantissa)); + } + + static SelfPtr float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, int64_t nan_repr) { + return c10::make_intrusive(ScalarType::float_( + exponent, mantissa, finite_values_only, NanRepr(nan_repr))); + } + + template + static void bind_readonly_property(torch::class_& cls, + std::string const& name, T Base::*field) { + auto getter_func = [field = std::move(field)](SelfPtr const& self) { + if constexpr (std::is_member_function_pointer_v) { + return (self.get()->*field)(); + } else { + return self.get()->*field; + } + }; + + cls.def_property(name, getter_func); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + MemberFunc Cls::*member) { + cls.def(name, [member = std::move(member)](SelfPtr const& self) { + return (self.get()->*member)(); + }); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + Func func) { + cls.def(name, func); + } + + template + static void bind_static_function(torch::class_& cls, + const std::string& 
name, Func func) { + cls.def_static(name, func); + } + + static void bind_class(torch::Library& lib) { + auto cls = lib.class_("ScalarType") + .def(torch::init()); + + // Bind Properties + bind_readonly_property(cls, "mantissa", &Base::mantissa); + bind_readonly_property(cls, "exponent", &Base::exponent); + bind_readonly_property(cls, "bias", &Base::bias); + bind_readonly_property(cls, "signed", &Base::is_signed); + bind_readonly_property(cls, "size_bits", &Base::size_bits); + + // Bind member functions + bind_function(cls, "is_signed", &Base::is_signed); + bind_function(cls, "is_integer", &Base::is_integer); + bind_function(cls, "is_floating_point", &Base::is_floating_point); + bind_function(cls, "is_ieee_754", &Base::is_ieee_754); + bind_function(cls, "has_nans", &Base::has_nans); + bind_function(cls, "has_infs", &Base::has_infs); + bind_function(cls, "has_bias", &Base::has_bias); + + bind_function(cls, "max", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->max()); + }); + bind_function(cls, "min", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->min()); + }); + + bind_function(cls, "__str__", &Base::str); + bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { + return *self == *other; + }); + bind_function(cls, "__repr__", [](SelfPtr const& self) { + return "ScalarType." + self.get()->str(); + }); + + // Bind static functions (convenience constructors) + bind_static_function(cls, "int_", &ScalarTypeTorch::int_); + bind_static_function(cls, "uint", &ScalarTypeTorch::uint); + bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754); + bind_static_function(cls, "float_", &ScalarTypeTorch::float_); + } +}; + +using ScalarTypeTorchPtr = c10::intrusive_ptr; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names 
+static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +}; // namespace vllm diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp new file mode 100644 index 000000000000..f60254189a2f --- /dev/null +++ b/csrc/core/torch_bindings.cpp @@ -0,0 +1,16 @@ +#include + +#include "scalar_type.hpp" +#include "registration.h" + +// Note the CORE exstension will be built for (almost) all hardware targets so +// new additions must account for this. (currently not built for TPU and Neuron) + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) { + // ScalarType, a custom class for representing data types that supports + // quantized types, declared here so it can be used when creating interfaces + // for custom ops. + vllm::ScalarTypeTorch::bind_class(lib); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 5be0e9810b5b..cf7d977da7c1 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -1,9 +1,11 @@ #include "cache.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include +void init_cpu_threads_env(const std::string& cpu_ids); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { + // CPU utils + utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); +} + REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp new file mode 100644 index 000000000000..5782580baa86 --- /dev/null +++ b/csrc/cpu/utils.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include + +#include "cpu_types.hpp" + +void init_cpu_threads_env(const std::string& cpu_ids) { + bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); + TORCH_CHECK(omp_cpu_mask->size > 0); + std::vector omp_cpu_ids; + omp_cpu_ids.reserve(omp_cpu_mask->size); + + constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp); + + for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) { + unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size]; + int i = 0; + while (group_mask) { + if (group_mask & 1) { + omp_cpu_ids.emplace_back(offset + i); + } + ++i; + group_mask >>= 1; + } + } + + // Memory node binding + if (numa_available() != -1) { + int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); + bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); + bitmask* src_mask = numa_get_membind(); + + int pid = getpid(); + + // move all existing pages to the specified numa node. + *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); + int page_num = numa_migrate_pages(pid, src_mask, mask); + if (page_num == -1) { + TORCH_CHECK(false, + "numa_migrate_pages failed. errno: " + std::to_string(errno)); + } + + // restrict memory allocation node. 
+ numa_set_membind(mask); + numa_set_strict(1); + } + + // OMP threads binding + omp_set_num_threads((int)omp_cpu_ids.size()); + torch::set_num_threads((int)omp_cpu_ids.size()); + TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); + TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); +#pragma omp parallel for schedule(static, 1) + for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { + cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size); + size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size); + CPU_ZERO_S(size, mask); + CPU_SET_S(omp_cpu_ids[i], size, mask); + sched_setaffinity(0, sizeof(cpu_set_t), mask); + CPU_FREE(mask); + } + + numa_free_nodemask(omp_cpu_mask); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 243752b9a9e8..86e42af44df1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,4 +1,4 @@ -#include "registration.h" +#include "core/registration.h" #include "moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/ops.h b/csrc/ops.h index 9ef1fcb465bf..3bd4a9eda5ee 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,6 +3,8 @@ #include #include +#include "core/scalar_type.hpp" + void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, @@ -84,16 +86,19 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k); torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp); + bool is_k_full, bool has_zp, + bool use_fp32_reduce); torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, @@ -114,6 +119,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); +torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, + torch::Tensor const& b_q_weight, + torch::Tensor const& s_tok, + torch::Tensor const& s_ch, + torch::Tensor const& s_group, + torch::Tensor& workspace, int64_t size_m, + int64_t size_n, int64_t size_k); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE deleted file mode 100644 index a46e2cdcadf7..000000000000 --- a/csrc/punica/LICENSE +++ /dev/null @@ -1,217 +0,0 @@ -Contains code from https://github.com/punica-ai/punica - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. 
- - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. - - -Apache-2.0 -* third_party/nvbench (with LLVM exception) -* third_party/flashinfer - -BSD-3-Clause: -* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu deleted file mode 100644 index 86846c274c90..000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu deleted file mode 100644 index de39c3121f5d..000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h deleted file mode 100644 index 2c8d007d8719..000000000000 --- a/csrc/punica/bgmv/bgmv_config.h +++ /dev/null @@ -1,218 +0,0 @@ -#pragma once - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale); - -// clang-format off - -#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, narrow, 128) \ - f(in_T, out_T, W_T, narrow, 256) \ - f(in_T, out_T, W_T, narrow, 512) \ - f(in_T, out_T, W_T, narrow, 640) \ - f(in_T, out_T, W_T, narrow, 768) \ - f(in_T, out_T, W_T, narrow, 896) \ - f(in_T, out_T, W_T, narrow, 1024) \ - f(in_T, out_T, W_T, narrow, 1152) \ - f(in_T, out_T, W_T, narrow, 1216) \ - f(in_T, out_T, W_T, narrow, 1280) \ - f(in_T, out_T, W_T, narrow, 1536) \ - f(in_T, out_T, W_T, narrow, 1664) \ - f(in_T, out_T, W_T, narrow, 1728) \ - f(in_T, out_T, W_T, narrow, 1792) \ - f(in_T, out_T, W_T, narrow, 2048) \ - f(in_T, out_T, W_T, narrow, 2240) \ - f(in_T, out_T, W_T, narrow, 2304) \ - f(in_T, out_T, W_T, narrow, 2368) \ - f(in_T, out_T, W_T, narrow, 2432) \ - f(in_T, out_T, W_T, narrow, 2560) \ - f(in_T, out_T, W_T, narrow, 2752) \ - f(in_T, out_T, W_T, narrow, 2816) \ - f(in_T, out_T, W_T, narrow, 3072) \ - f(in_T, out_T, W_T, narrow, 3328) \ - f(in_T, out_T, W_T, narrow, 3456) \ - f(in_T, out_T, W_T, narrow, 3584) \ - f(in_T, out_T, W_T, narrow, 3712) \ - f(in_T, out_T, W_T, narrow, 4096) \ - f(in_T, out_T, W_T, narrow, 4480) \ - f(in_T, out_T, W_T, narrow, 4608) \ - f(in_T, out_T, W_T, narrow, 4736) \ - f(in_T, out_T, W_T, narrow, 4864) \ - f(in_T, out_T, W_T, narrow, 5120) \ - f(in_T, out_T, W_T, narrow, 5504) \ - f(in_T, out_T, W_T, narrow, 5632) \ - f(in_T, out_T, W_T, narrow, 5888) \ - f(in_T, out_T, W_T, narrow, 6144) \ - f(in_T, out_T, W_T, narrow, 6400) \ - f(in_T, out_T, W_T, narrow, 6848) \ - f(in_T, out_T, W_T, narrow, 6912) \ - f(in_T, out_T, W_T, narrow, 7168) \ - f(in_T, out_T, W_T, narrow, 7424) \ - f(in_T, out_T, W_T, narrow, 8192) \ - f(in_T, out_T, W_T, narrow, 8960) \ - 
f(in_T, out_T, W_T, narrow, 9216) \ - f(in_T, out_T, W_T, narrow, 9472) \ - f(in_T, out_T, W_T, narrow, 10240) \ - f(in_T, out_T, W_T, narrow, 11008) \ - f(in_T, out_T, W_T, narrow, 11264) \ - f(in_T, out_T, W_T, narrow, 12288) \ - f(in_T, out_T, W_T, narrow, 13696) \ - f(in_T, out_T, W_T, narrow, 13824) \ - f(in_T, out_T, W_T, narrow, 14336) \ - f(in_T, out_T, W_T, narrow, 14784) \ - f(in_T, out_T, W_T, narrow, 14848) \ - f(in_T, out_T, W_T, narrow, 15360) \ - f(in_T, out_T, W_T, narrow, 16384) \ - f(in_T, out_T, W_T, narrow, 18944) \ - f(in_T, out_T, W_T, narrow, 20480) \ - f(in_T, out_T, W_T, narrow, 22016) \ - f(in_T, out_T, W_T, narrow, 22528) \ - f(in_T, out_T, W_T, narrow, 24576) \ - f(in_T, out_T, W_T, narrow, 27392) \ - f(in_T, out_T, W_T, narrow, 27648) \ - f(in_T, out_T, W_T, narrow, 28672) \ - f(in_T, out_T, W_T, narrow, 29568) \ - f(in_T, out_T, W_T, narrow, 29696) \ - f(in_T, out_T, W_T, narrow, 32000) \ - f(in_T, out_T, W_T, narrow, 32256) \ - f(in_T, out_T, W_T, narrow, 32512) \ - f(in_T, out_T, W_T, narrow, 32768) \ - f(in_T, out_T, W_T, narrow, 33024) \ - f(in_T, out_T, W_T, narrow, 36864) \ - f(in_T, out_T, W_T, narrow, 43264) \ - f(in_T, out_T, W_T, narrow, 49152) \ - f(in_T, out_T, W_T, narrow, 49408) \ - f(in_T, out_T, W_T, narrow, 60544) \ - f(in_T, out_T, W_T, narrow, 60672) \ - f(in_T, out_T, W_T, narrow, 64000) \ - f(in_T, out_T, W_T, narrow, 64256) \ - f(in_T, out_T, W_T, narrow, 64512) \ - f(in_T, out_T, W_T, narrow, 102400) \ - f(in_T, out_T, W_T, narrow, 102656) \ - f(in_T, out_T, W_T, narrow, 102912) \ - f(in_T, out_T, W_T, narrow, 128000) \ - f(in_T, out_T, W_T, narrow, 128256) \ - f(in_T, out_T, W_T, narrow, 128512) \ - - -// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA -// and vllm/tests/lora/test_punica.py - -// Used for defining kernels going from the variety of -// dim in to the narrow dim out - // Using it for the fully sharded column - // parallel LoRA A which splits the rank dim -#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, 128, narrow) \ - f(in_T, out_T, W_T, 256, narrow) \ - f(in_T, out_T, W_T, 512, narrow) \ - f(in_T, out_T, W_T, 640, narrow) \ - f(in_T, out_T, W_T, 768, narrow) \ - f(in_T, out_T, W_T, 896, narrow) \ - f(in_T, out_T, W_T, 1024, narrow) \ - f(in_T, out_T, W_T, 1152, narrow) \ - f(in_T, out_T, W_T, 1216, narrow) \ - f(in_T, out_T, W_T, 1280, narrow) \ - f(in_T, out_T, W_T, 1536, narrow) \ - f(in_T, out_T, W_T, 1664, narrow) \ - f(in_T, out_T, W_T, 1728, narrow) \ - f(in_T, out_T, W_T, 1792, narrow) \ - f(in_T, out_T, W_T, 2048, narrow) \ - f(in_T, out_T, W_T, 2240, narrow) \ - f(in_T, out_T, W_T, 2304, narrow) \ - f(in_T, out_T, W_T, 2368, narrow) \ - f(in_T, out_T, W_T, 2432, narrow) \ - f(in_T, out_T, W_T, 2560, narrow) \ - f(in_T, out_T, W_T, 2752, narrow) \ - f(in_T, out_T, W_T, 2816, narrow) \ - f(in_T, out_T, W_T, 3072, narrow) \ - f(in_T, out_T, W_T, 3328, narrow) \ - f(in_T, out_T, W_T, 3456, narrow) \ - f(in_T, out_T, W_T, 3584, narrow) \ - f(in_T, out_T, W_T, 3712, narrow) \ - f(in_T, out_T, W_T, 4096, narrow) \ - f(in_T, out_T, W_T, 4480, narrow) \ - f(in_T, out_T, W_T, 4608, narrow) \ - f(in_T, out_T, W_T, 4736, narrow) \ - f(in_T, out_T, W_T, 4864, narrow) \ - f(in_T, out_T, W_T, 5120, narrow) \ - f(in_T, out_T, W_T, 5504, narrow) \ - f(in_T, out_T, W_T, 5632, narrow) \ - f(in_T, out_T, W_T, 5888, narrow) \ - f(in_T, out_T, W_T, 6144, narrow) \ - f(in_T, out_T, W_T, 6400, narrow) \ - f(in_T, out_T, W_T, 6848, narrow) \ - f(in_T, out_T, W_T, 6912, narrow) \ - 
f(in_T, out_T, W_T, 7168, narrow) \ - f(in_T, out_T, W_T, 7424, narrow) \ - f(in_T, out_T, W_T, 8192, narrow) \ - f(in_T, out_T, W_T, 8960, narrow) \ - f(in_T, out_T, W_T, 9216, narrow) \ - f(in_T, out_T, W_T, 9472, narrow) \ - f(in_T, out_T, W_T, 10240, narrow) \ - f(in_T, out_T, W_T, 11008, narrow) \ - f(in_T, out_T, W_T, 11264, narrow) \ - f(in_T, out_T, W_T, 12288, narrow) \ - f(in_T, out_T, W_T, 13696, narrow) \ - f(in_T, out_T, W_T, 13824, narrow) \ - f(in_T, out_T, W_T, 14336, narrow) \ - f(in_T, out_T, W_T, 14784, narrow) \ - f(in_T, out_T, W_T, 14848, narrow) \ - f(in_T, out_T, W_T, 15360, narrow) \ - f(in_T, out_T, W_T, 16384, narrow) \ - f(in_T, out_T, W_T, 18944, narrow) \ - f(in_T, out_T, W_T, 20480, narrow) \ - f(in_T, out_T, W_T, 22016, narrow) \ - f(in_T, out_T, W_T, 22528, narrow) \ - f(in_T, out_T, W_T, 24576, narrow) \ - f(in_T, out_T, W_T, 27392, narrow) \ - f(in_T, out_T, W_T, 27648, narrow) \ - f(in_T, out_T, W_T, 28672, narrow) \ - f(in_T, out_T, W_T, 29568, narrow) \ - f(in_T, out_T, W_T, 29696, narrow) \ - f(in_T, out_T, W_T, 32000, narrow) \ - f(in_T, out_T, W_T, 32256, narrow) \ - f(in_T, out_T, W_T, 32512, narrow) \ - f(in_T, out_T, W_T, 32768, narrow) \ - f(in_T, out_T, W_T, 33024, narrow) \ - f(in_T, out_T, W_T, 36864, narrow) \ - f(in_T, out_T, W_T, 43264, narrow) \ - f(in_T, out_T, W_T, 49152, narrow) \ - f(in_T, out_T, W_T, 49408, narrow) \ - f(in_T, out_T, W_T, 60544, narrow) \ - f(in_T, out_T, W_T, 60672, narrow) \ - f(in_T, out_T, W_T, 64000, narrow) \ - f(in_T, out_T, W_T, 64256, narrow) \ - f(in_T, out_T, W_T, 64512, narrow) \ - f(in_T, out_T, W_T, 102400, narrow) \ - f(in_T, out_T, W_T, 102656, narrow) \ - f(in_T, out_T, W_T, 102912, narrow) \ - f(in_T, out_T, W_T, 128000, narrow) \ - f(in_T, out_T, W_T, 128256, narrow) \ - f(in_T, out_T, W_T, 128512, narrow) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA - - -// Keep this in sync with vllm/config::LoRAConfig -#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) - - -#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ - f(in_T, out_T, W_T, 8, 64) \ - f(in_T, out_T, W_T, 16, 64) \ - f(in_T, out_T, W_T, 32, 64) \ - f(in_T, out_T, W_T, 64, 64) - -// clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu deleted file mode 100644 index d225a1eaa82b..000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu deleted file mode 100644 index b37d288a7556..000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu deleted file mode 100644 index a1ab2deecbab..000000000000 --- 
a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu deleted file mode 100644 index 0b35bf569989..000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh deleted file mode 100644 index 8a3b8403b4a6..000000000000 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ /dev/null @@ -1,451 +0,0 @@ -#pragma once - -#include -#ifndef USE_ROCM -#include -#else -#include -#endif -#ifndef USE_ROCM -#include -#endif -#include -#include -#include - -#include "vec_dtypes.cuh" - -namespace cg = cooperative_groups; - -#ifdef USE_ROCM -template -__host__ __device__ -inline void* memcpy_blocking(void *dst, const void *src) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; -#pragma unroll - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; -} -#endif - -#ifndef USE_ROCM - -// nthrs = (32, 4) -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t j = blockIdx.x; - constexpr size_t num_pipeline_stages = 2; - constexpr size_t tile_size = tx * ty * vec_size; - __shared__ W_T W_shared[num_pipeline_stages * tile_size]; - __shared__ in_T X_shared[num_pipeline_stages * tile_size]; - __shared__ float y_warpwise[ty]; - - size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - auto pipe = cuda::make_pipeline(); - - // pipeline load W/X and compute WX; - pipe.producer_acquire(); - cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - pipe.producer_commit(); - size_t copy_idx, compute_idx; - float y = 0.f; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; - ++tile_idx) { - copy_idx = tile_idx % num_pipeline_stages; - // pipeline stage: async copy W fragment - pipe.producer_acquire(); - if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { - cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - 
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - } - pipe.producer_commit(); - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // pipeline stage: compute WX - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = sum; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - } - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // final pipeline stage - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = - ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) - ? sum - : 0.f; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - - // write Y; - if (block.thread_rank() == 0) { - Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); - } -} - -#else - -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - size_t j = blockIdx.x; - constexpr size_t tile_size = tx * ty * vec_size; - constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; - __shared__ float y_warpwise[ty]; - - float y = 0; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - x_vec.load(X + (batch_idx * feat_in) + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - } - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += VLLM_SHFL_DOWN_SYNC(sum, offset); - } - - __syncthreads(); - - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - y += sum; - } - } - - if (threadIdx.x == 0) { - y_warpwise[threadIdx.y] = y; - } - 
__syncthreads(); - - float y_write = 0.f; -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y_write += y_warpwise[i]; - } - - // write Y; - if (threadIdx.x == 0 && threadIdx.y == 0) { - size_t y_idx = batch_idx * full_y_size + y_offset + j; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); - } -} - -#endif - -// nthrs = (2, 16, 4) -template -__global__ void -bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t tile_idx = blockIdx.x; - - // load X; - vec_t x_vec; - x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); - - // load W; - vec_t w_vec; - w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + - block.thread_rank() * vec_size); - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { -#ifndef USE_ROCM - sum += float(w_vec[i]) * float(x_vec[i]) * scale; -#else - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; -#endif - } - - cg::thread_block_tile g = cg::tiled_partition(block); -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += g.shfl_down(sum, offset); - } - sum = g.shfl(sum, 0); - - if (threadIdx.x == 0) { -#ifndef USE_ROCM - Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y] += static_cast(sum); -#else - size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); -#endif - } -} - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - constexpr size_t vec_size = 8; - constexpr int tz = 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if constexpr (feat_in <= feat_out) { - static_assert(feat_in % vec_size == 0); - constexpr int tx = feat_in / vec_size; - - static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || - (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || - (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); - - if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { - constexpr int ty = 32 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { - constexpr int ty = 16 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else { - constexpr int ty = 8 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } - } else { -#ifndef USE_ROCM - static_assert(feat_in % (vec_size * 32) == 0 || - feat_in % (vec_size * 16) == 0 || - feat_in % (vec_size * 8) == 0); - - if constexpr (feat_in % (vec_size * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, 
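Both deleted kernels finish their partial dot products with a shuffle-based reduction: `__shfl_down_sync` in the shrink path, and a cooperative-groups `thread_block_tile` in `bgmv_expand_kernel`. A minimal sketch of that tile reduction; `TX` stands in for the value the real kernel derives from `feat_in / vec_size`:

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// TX must be a power of two no larger than the warp size (32), which the
// static_asserts in the real bgmv dispatch also enforce for the tile width.
template <unsigned TX>
__device__ float tile_reduce_sum(float v) {
  cg::thread_block_tile<TX> g =
      cg::tiled_partition<TX>(cg::this_thread_block());
#pragma unroll
  for (unsigned offset = TX / 2; offset > 0; offset /= 2) {
    v += g.shfl_down(v, offset);  // fold the upper half of the tile onto the lower half
  }
  return g.shfl(v, 0);            // broadcast lane 0's total to every lane in the tile
}
```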
batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { - constexpr int tx = 16; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } -#else - constexpr size_t rocm_warp_size = warpSize; - -#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ - feat_in % (rocm_warp_size * vec_size_) == 0 - -#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ - if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ - constexpr size_t vec_size_shrink = vec_size_; \ - constexpr int tx = tx_; \ - constexpr int ty = ty_; \ - dim3 nblks(feat_out, batch_size); \ - dim3 nthrs(tx, ty); \ - bgmv_shrink_kernel \ - <<>>(Y, X, W, indicies, y_offset, \ - full_y_size, num_layers, layer_idx, \ - scale); \ - } - - static_assert(CHECK_INPUT_TILEABLE_BY(32) || - CHECK_INPUT_TILEABLE_BY(16) || - CHECK_INPUT_TILEABLE_BY( 8) || - CHECK_INPUT_TILEABLE_BY( 4) || - CHECK_INPUT_TILEABLE_BY( 2) || - CHECK_INPUT_TILEABLE_BY( 1)); - - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) - -#undef CHECK_INPUT_TILEABLE_BY -#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM -#endif - } -} - -#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ - template void bgmv_kernel( \ - out_T * __restrict__ Y, const in_T *__restrict__ X, \ - const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ - int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ - int64_t num_layers, int64_t layer_idx, float scale); - -#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ - INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) - -#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ - INST_BGMV(narrow, wide, in_T, out_T, W_T) \ - INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py deleted file mode 100644 index 972df5a7208c..000000000000 --- a/csrc/punica/bgmv/generator.py +++ /dev/null @@ -1,48 +0,0 @@ -DTYPES = ["fp16", "bf16", "fp32"] -DTYPE_MAP = { - "fp16": "nv_half", - "bf16": "nv_bfloat16", - "fp32": "float", -} - -TEMPLATE = """ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -""".lstrip() # noqa: E501 - -for input_dtype in DTYPES: - for output_dtype in DTYPES: - for weight_dtype in DTYPES: - if weight_dtype == "fp32": - # FP32 weights are not supported. - continue - if output_dtype == "fp32": - # LoRA A matrix. 
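The `INST_BGMV*` macros above and the generated `bgmv_{in}_{out}_{w}.cu` files exist to force explicit template instantiations of `bgmv_kernel` for every supported shape/dtype combination, so the dispatcher in `punica_ops.cu` has concrete symbols to link against. A stripped-down illustration of the pattern; the names and shape constants below are simplified stand-ins, not the real macro arguments:

```cuda
// The implementation header defines the template; each generated translation
// unit explicitly instantiates it for the shape/dtype pairs it owns.
template <int feat_in, int feat_out, typename T>
void bgmv_example(T* y, const T* x, float scale) {
  // ... launch the kernel specialized for this compile-time shape ...
  (void)y; (void)x; (void)scale;
}

#define INST_BGMV_EXAMPLE(feat_in, feat_out, T) \
  template void bgmv_example<feat_in, feat_out, T>(T*, const T*, float);

INST_BGMV_EXAMPLE(16, 4096, float)   // one line per supported (narrow, wide) pair
INST_BGMV_EXAMPLE(4096, 16, float)
```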
- if input_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # input and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif input_dtype == "fp32": - # LoRA B matrix. - if output_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # output and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif not (input_dtype == output_dtype == weight_dtype): - # NOTE(woosuk): While Punica supports mixed data types for - # input, output, and weight, we only generate the kernels with - # the same data types to reduce the binary size. - continue - - kernel_definition = TEMPLATE.format( - input_dtype=DTYPE_MAP[input_dtype], - output_dtype=DTYPE_MAP[output_dtype], - weight_dtype=DTYPE_MAP[weight_dtype]) - filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" - with open(filename, "w") as f: - f.write(kernel_definition) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh deleted file mode 100644 index 2738892e6dc4..000000000000 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ /dev/null @@ -1,1325 +0,0 @@ -#ifndef VEC_DTYPES_CUH_ -#define VEC_DTYPES_CUH_ - -#ifdef FLASHINFER_USE_FP8 -#include -#endif -#include - -#include - -#include "../type_convert.h" -#include "../../cuda_compat.h" - -#define FLASHINFER_INLINE \ - inline __attribute__((always_inline)) __device__ __host__ - -template -struct vec_t { - FLASHINFER_INLINE float_t &operator[](size_t i); - FLASHINFER_INLINE const float_t &operator[](size_t i) const; - FLASHINFER_INLINE void fill(float_t val); - FLASHINFER_INLINE void load(const float_t *ptr); - FLASHINFER_INLINE void store(float_t *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src); - template - FLASHINFER_INLINE void cast_load(const T *ptr); - template - FLASHINFER_INLINE void cast_store(T *ptr) const; - FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); -}; - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = tgt_float_t(src[i]); - } -} - -template -FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, - vec_t &dst) { - if constexpr (std::is_same::value) { - dst.load(src_ptr); - } else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } -} - -template -FLASHINFER_INLINE void cast_store_impl(const vec_t &src, - tgt_float_t *dst_ptr) { - if constexpr (std::is_same::value) { - src.store(dst_ptr); - } else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } -} - -#ifdef FLASHINFER_USE_FP8 -/******************* vec_t<__nv_fp8_e4m3> *******************/ - -// __nv_fp8_e4m3 x 1 -template <> -struct vec_t<__nv_fp8_e4m3, 1> { - __nv_fp8_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - 
cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( - __nv_fp8_e4m3 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *dst = *src; -} - -// __nv_fp8_e4m3 x 2 -template <> -struct vec_t<__nv_fp8_e4m3, 2> { - __nv_fp8x2_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x2_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x2_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 4 - -template <> -struct vec_t<__nv_fp8_e4m3, 4> { - __nv_fp8x4_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x4_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x4_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); -} 
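Every `vec_t` specialization in the deleted `vec_dtypes.cuh` follows the same recipe visible in the fp8 x2/x4 cases above: back the vector with one packed machine word, expose per-element access by reinterpreting that word, and implement `fill()` by replicating the payload with shifts. A simplified, byte-sized analogue of that recipe (the `Byte4` type is purely illustrative):

```cuda
#include <cstddef>
#include <cstdint>
#include <cstring>

// One packed word as storage, per-element access by reinterpretation,
// fill() by shift-and-or replication.
struct Byte4 {
  uint32_t data;

  uint8_t& operator[](std::size_t i) {
    return reinterpret_cast<uint8_t*>(&data)[i];
  }

  void fill(uint8_t val) {
    // Same replication trick as vec_t<__nv_fp8_e4m3, 4>::fill above.
    data = (uint32_t(val) << 24) | (uint32_t(val) << 16) |
           (uint32_t(val) << 8) | uint32_t(val);
  }
  void load(const uint8_t* ptr) { std::memcpy(&data, ptr, sizeof(data)); }
  void store(uint8_t* ptr) const { std::memcpy(ptr, &data, sizeof(data)); }
};
```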
- -// __nv_fp8_e4m3 x 8 - -template <> -struct vec_t<__nv_fp8_e4m3, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { - ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( - __nv_fp8_e4m3 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 16 or more -template -struct vec_t<__nv_fp8_e4m3, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T 
*ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t<__nv_fp8_e5m2> *******************/ - -// __nv_fp8_e5m2 x 1 -template <> -struct vec_t<__nv_fp8_e5m2, 1> { - __nv_fp8_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( - __nv_fp8_e5m2 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *dst = *src; -} - -// __nv_fp8_e5m2 x 2 -template <> -struct vec_t<__nv_fp8_e5m2, 2> { - __nv_fp8x2_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x2_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x2_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 4 - -template <> -struct vec_t<__nv_fp8_e5m2, 4> { - __nv_fp8x4_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 
*ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x4_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x4_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 8 - -template <> -struct vec_t<__nv_fp8_e5m2, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { - ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( - __nv_fp8_e5m2 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 16 or more - -template -struct vec_t<__nv_fp8_e5m2, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - 
(__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; -#endif - -/******************* vec_t *******************/ - -// half x 1 -template <> -struct vec_t { - half data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *dst = *src; -} - -// half x 2 -template <> -struct vec_t { - half2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - data = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((half2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((half2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((half2 *)dst) = *((half2 
*)src); -} - -// half x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - *(half2 *)(&data.x) = make_half2(val, val); - *(half2 *)(&data.y) = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// half x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)data)[i]; - } - FLASHINFER_INLINE void fill(half val) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - *(half2 *)(&(data[i].x)) = make_half2(val, val); - *(half2 *)(&(data[i].y)) = make_half2(val, val); - *(half2 *)(&(data[i].z)) = make_half2(val, val); - *(half2 *)(&(data[i].w)) = make_half2(val, val); - } - } - FLASHINFER_INLINE void load(const half *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(half *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// nv_bfloat16 x 1 -template <> -struct vec_t { - nv_bfloat16 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) 
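The `half` and `nv_bfloat16` specializations in this file get their bandwidth from packing eight 16-bit values into a `uint2`/`uint4`, so loads and stores become single 64/128-bit transactions. A stand-alone sketch of the wide-load idea; `Half8` is a stand-in for `vec_t<half, 8>`, and the 16-byte alignment it requires is the same assumption the original `load`/`store` make:

```cuda
#include <cuda_fp16.h>

struct alignas(16) Half8 {
  half v[8];
};

__device__ void load8(Half8& dst, const half* __restrict__ src) {
  // One 128-bit load instead of eight 16-bit loads.
  *reinterpret_cast<uint4*>(&dst) = *reinterpret_cast<const uint4*>(src);
}

__device__ void store8(half* __restrict__ dst, const Half8& src) {
  *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(&src);
}
```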
{ - data = *ptr; -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *dst = *src; -} - -// nv_bfloat16 x 2 -template <> -struct vec_t { - nv_bfloat162 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((nv_bfloat162 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((nv_bfloat162 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); -} - -// nv_bfloat16 x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// nv_bfloat16 x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); - } - } - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { -#pragma unoll - for (size_t i = 0; i < vec_size 
/ 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// float x 1 - -template <> -struct vec_t { - float data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *dst = *src; -} - -// float x 2 - -template <> -struct vec_t { - float2 data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { - data = make_float2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { - data = *((float2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { - *((float2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); -} - -// float x 4 or more -template -struct vec_t { - float4 data[vec_size / 4]; - - FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(data))[i]; - } - FLASHINFER_INLINE void fill(float val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } - } - FLASHINFER_INLINE void load(const float *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; - } - } - FLASHINFER_INLINE 
void store(float *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; - } - } -}; - -/******************* vec_t type cast *******************/ - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = half(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = - __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = nv_bfloat16(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162 *)(&dst.data))[i] = - __float22bfloat162_rn(((float2 *)(&src.data))[i]); - } - } -} - -#ifdef FLASHINFER_USE_FP8 - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = - __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i 
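The `cast_from_impl` overloads below convert two lanes at a time with the paired intrinsics (`__half22float2`, `__float22half2_rn`, `__bfloat1622float2`, and their inverses) instead of looping over scalars. A minimal sketch of that packed-pair strategy for the half-to-float direction; `VEC` is illustrative and assumed even, and `src` is assumed aligned so it can be read as `half2`:

```cuda
#include <cuda_fp16.h>

template <int VEC>
__device__ void half_to_float_packed(const half* __restrict__ src,
                                     float* __restrict__ dst) {
#pragma unroll
  for (int i = 0; i < VEC / 2; ++i) {
    // Convert a lane pair at once, as cast_from_impl does with __half22float2.
    float2 f = __half22float2(reinterpret_cast<const half2*>(src)[i]);
    dst[2 * i] = f.x;
    dst[2 * i + 1] = f.y;
  }
}
```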
< vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e5m2(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = - __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -#endif // FLASHINFER_USE_FP8 - -#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu deleted file mode 100644 index dd29820144b3..000000000000 --- a/csrc/punica/punica_ops.cu +++ /dev/null @@ -1,569 +0,0 @@ -#include -#include -#include - -#include "type_convert.h" -#include "../cuda_compat.h" -#include "bgmv/bgmv_config.h" - - -//====== utils ====== - -inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, - const char *a_name, const char *b_name) { - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", - a.dim(), " vs ", b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, - ".size(", i, ")"); - } -} - -inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { - return (uint64_t(a) << 32) | uint64_t(b); -} - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -#define CHECK_DIM(d, x) \ - TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") - -#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) - -#define CHECK_EQ(a, b) \ - TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. 
", a, " vs ", b) - -//====== bgmv ====== - -template -inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, - const int64_t *lora_indices, - uint32_t in_features, uint32_t out_features, - int64_t y_offset, int64_t full_y_size, - int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - // NOTE(woosuk): While Punica supports various combinations of input/output - // data types, we limit the supported data types to reduce the binary size. - constexpr bool is_input_float = std::is_same::value; - constexpr bool is_output_float = std::is_same::value; - if (is_input_float) { - if (!std::is_same::value) { - return false; - } - } else if (is_output_float) { - if (!std::is_same::value) { - return false; - } - } else if (!(std::is_same::value && - std::is_same::value)) { - return false; - } - - switch (pack_u32(in_features, out_features)) { -#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u32(feat_in, feat_out): \ - bgmv_kernel(Y, X, W, lora_indices, y_offset, \ - full_y_size, batch_size, num_layers, \ - layer_idx, scale); \ - break; -#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) - - FOR_BGMV_WIDE_NARROW(CASE, _, _, _) - FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) -#undef CASE -#undef CASE_ONESIDE - default: - return false; - } - return true; -} - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t h_in = x.size(1); - int64_t h_out = y.size(1); - int64_t num_layers = w.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - 
static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - 
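`launch_bgmv_kernel` avoids a two-level shape switch by packing `(in_features, out_features)` into a single 64-bit key with `pack_u32` and switching on that key. The same trick in isolation; the shape constants below are illustrative, not the real supported list:

```cuda
#include <cstdint>

inline constexpr uint64_t pack_u32_example(uint32_t a, uint32_t b) {
  return (uint64_t(a) << 32) | uint64_t(b);
}

bool dispatch_example(uint32_t in_features, uint32_t out_features) {
  // Each case selects a kernel that was explicitly instantiated elsewhere.
  switch (pack_u32_example(in_features, out_features)) {
    case pack_u32_example(16, 4096):
      // bgmv_kernel<16, 4096>(...);   // expand: narrow -> wide
      return true;
    case pack_u32_example(4096, 16):
      // bgmv_kernel<4096, 16>(...);   // shrink: wide -> narrow
      return true;
    default:
      return false;  // caller reports "No suitable kernel."
  }
}
```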
static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); -} - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t num_layers = w.size(1); - int64_t full_y_size = y.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - 
y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), 
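The `CHECK_DIM`/`CHECK_EQ` guards in `dispatch_bgmv` and `dispatch_bgmv_low_level` pin down the tensor layout the kernels expect: `x` is `[B, h_in]`, `y` is `[B, h_out]` (or `[B, full_y_size]` in the low-level variant), `w` is `[num_loras, num_layers, h_out, h_in]`, and `indicies` holds one int64 LoRA slot index per input row. A host-side sketch of that contract; the sizes are illustrative and the commented call only marks where the op would be invoked:

```cuda
#include <torch/torch.h>

void example_dispatch_bgmv_inputs() {
  const int64_t B = 4, h_in = 16, h_out = 4096;
  const int64_t num_loras = 8, num_layers = 32;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);

  torch::Tensor x = torch::randn({B, h_in}, opts);
  torch::Tensor y = torch::zeros({B, h_out}, opts);  // accumulated into, not overwritten
  torch::Tensor w = torch::randn({num_loras, num_layers, h_out, h_in}, opts);
  torch::Tensor indicies = torch::zeros(
      {B}, torch::dtype(torch::kInt64).device(torch::kCUDA));

  // dispatch_bgmv(y, x, w, indicies, /*layer_idx=*/0, /*scale=*/1.0);
}
```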
" out_dtype=", y.scalar_type()); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h deleted file mode 100644 index 5d625d0564f7..000000000000 --- a/csrc/punica/punica_ops.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale); - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset); diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp deleted file mode 100644 index 894e229b6d9d..000000000000 --- a/csrc/punica/torch_bindings.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "registration.h" -#include "punica_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def( - "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " - "layer_idx, float scale) -> ()"); - m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - - m.def( - "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," - "Tensor indicies, int layer_idx," - "float scale, int h_in, int h_out," - "int y_offset) -> ()"); - m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h deleted file mode 100644 index dff7ce49283d..000000000000 --- a/csrc/punica/type_convert.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ -#define CSRC__PUNICA__TYPE_CONVERT_H__ - -#ifndef USE_ROCM - -#include -#include - -#else - -#include -#include - -#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ - -typedef __half nv_half; -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { - return __hip_bfloat162{val, val}; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { - return __hip_bfloat162{vall, valr}; -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T_dst convert_type(T_src val) { - return static_cast(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__half, float>(__half val) { - return __half2float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half convert_type(float val) { - return __float2half(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { - return __bfloat162float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 convert_type(float val) { - return __float2bfloat16(val); -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T vllm_add(T a, T b) { - return a + b; -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half vllm_add<__half>(__half a, __half b) { - return __hadd(a, b); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { - return __hadd(a, b); -} - -#undef __TYPE_CONVERT__HOST_DEVICE__ - -#endif // USE_ROCM - -#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 8fb985680086..22da5e4f08a1 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -273,8 +273,6 @@ __global__ void Code2x8Dequant( } 
__syncthreads(); - float res = 0; - int iters = (prob_k / 8 - 1) / (8 * 32) + 1; while (iters--) { if (pred && a_gl_rd < a_gl_end) { diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index 813ec6716cf5..5fa4b5f64027 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { return result; #endif + __builtin_unreachable(); // Suppress missing return statement warning } } // namespace awq diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 6d6da5f3d874..9da724a1b43c 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} namespace vllm { namespace awq { -// Pack two half values. -static inline __device__ __host__ unsigned __pack_half2(const half x, - const half y) { - unsigned v0 = *((unsigned short*)&x); - unsigned v1 = *((unsigned short*)&y); - return (v1 << 16) | v0; -} - template __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, @@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64) __shared__ half A_shared[16 * (32 + 8)]; __shared__ half B_shared[32 * (N + 8)]; - __shared__ half scaling_factors_shared[N]; - __shared__ half zeros_shared[N]; - int j_factors1 = ((OC + N - 1) / N); - int blockIdx_x = 0; int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); @@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64) static constexpr int row_stride_warp = 32 * 8 / 32; static constexpr int row_stride = 2 * 32 * 8 / N; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + @@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / - // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x - // % (cta_N / 8)) * 8); // - zero and * scale // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = // q * scale - zero * scale. 
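
The TODO kept above suggests reformulating the AWQ dequantization as deq = q * scale - zero * scale. The point of that identity is that zero * scale can be hoisted out of the per-element path, leaving a single fused multiply-add per value; below is a minimal host-side sketch of the identity (the scale and zero point are made-up values, purely illustrative, and the two forms may differ by at most one rounding step):

// Illustrative only: compares (q - zero) * scale with fma(q, scale, -(zero * scale)).
#include <cmath>
#include <cstdio>

int main() {
  const float scale = 0.0125f;           // hypothetical per-group scale
  const float zero = 8.0f;               // hypothetical zero point
  const float neg_zs = -(zero * scale);  // hoisted once per group
  for (int q = 0; q < 16; ++q) {         // all 4-bit codes
    const float sub_then_mul = (static_cast<float>(q) - zero) * scale;
    const float fused = std::fma(static_cast<float>(q), scale, neg_zs);
    std::printf("q=%2d  %.6f  %.6f\n", q, sub_then_mul, fused);
  }
  return 0;
}
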
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64) __global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, int G) { - int j_factors1 = 4; - int row_stride2 = 4; - int split_k_iters = 1; static constexpr uint32_t ZERO = 0x0; half B_shared[32 * (128 + 8)]; half* B_shared_ptr2 = B_shared; - half B_shared_warp[32]; - int OC = 512; - int N = blockDim.x * gridDim.x; // 2 int col = (blockIdx.x * blockDim.x + threadIdx.x); int row = blockIdx.y * blockDim.y + threadIdx.y; diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp index 877a9f5b9e5d..58b1e8ff159f 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp @@ -64,8 +64,6 @@ using namespace detail; // Row vector broadcast template< - // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races int Stages, class CtaTileShapeMNK, class Element, @@ -73,14 +71,12 @@ template< int Alignment = 128 / sizeof_bits_v > struct Sm90RowOrScalarBroadcast { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias - (cute::is_same_v>)); // batched row vector broadcast + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); - // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; }; // This struct has been modified to have a bool indicating that ptr_row is a @@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast { CUTLASS_HOST_DEVICE Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), - smem_row(const_cast(shared_storage.smem_row.data())) { } + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } Params params; - Element* smem_row; + Element *smem = nullptr; CUTLASS_DEVICE bool is_producer_load_needed() const { - return true; + return false; } CUTLASS_DEVICE bool @@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast { return (!params.row_broadcast && *(params.ptr_row) == Element(0)); } - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) - : gRow(cute::forward(gRow)), - sRow(cute::forward(sRow)), - params(params) {} - - GTensor gRow; // (CTA_M,CTA_N) - STensor sRow; // (CTA_M,CTA_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { - if (!params.row_broadcast) { - return; - } - - if (issue_tma_load) { - // Increment the expect-tx 
count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); - // Filter so we don't issue redundant copies over stride-0 modes - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); - } - } - }; - template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(gRow), cute::move(sRow), params); + return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) - : tCrRow(cute::forward(tCrRow)), - tCsRow(cute::forward(tCsRow)), - params(params) {} - - RTensor tCrRow; // (CPY,CPY_M,CPY_N) - STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; Params const& params; CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + begin() { if (!params.row_broadcast) { - fill(tCrRow, *(params.ptr_row)); + fill(tSR_rRow, *(params.ptr_row)); return; } + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be 
issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tCrRow(epi_v * FragmentSize + i); + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); } return frg_row; @@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCsRow), params); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); } }; @@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast { return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - 
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) - : tCgCol(cute::forward(tCgCol)), - tCrCol(cute::forward(tCrCol)), - params(params) {} + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) Params const& params; + int m; CUTLASS_DEVICE void begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + if (!params.col_broadcast) { fill(tCrCol, *(params.ptr_col)); return; @@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast { // Filter so we don't issue redundant copies over stride-0 modes // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCgCol), filter(tCrCol)); + copy_if(pred, filter(tCgCol), filter(tCrCol)); } template @@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast { mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks( - cute::move(tCgCol), cute::move(tCrCol), params); + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); } }; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 6ce25c5ac897..8d0dfee7bf23 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -1,469 +1,17 @@ #include #include - -#include - -// clang-format will break include orders -// clang-format off -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/util/device_memory.h" - #include "cutlass/cutlass.h" -#include "cutlass/gemm_coord.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/device/gemm.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" - -#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" -#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" - -#include "broadcast_load_epilogue_c2x.hpp" -#include "common.hpp" -// clang-format on -using namespace cute; +#include "scaled_mm_c2x.cuh" +#include "scaled_mm_c2x_sm75_dispatch.cuh" +#include "scaled_mm_c2x_sm80_dispatch.cuh" +#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" +#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. 
- Epilogues must contain a public type named EVTCompute of type Sm80EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. -*/ - -namespace { - -// Wrappers for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm75_to_sm80 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm80_to_sm89 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm89_to_sm90 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -/* - * This class provides the common ScaleA and ScaleB descriptors for the - * ScaledEpilogue and ScaledEpilogueBias classes. - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; - - using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. 
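
For reference, the operation described in the comment above, D = (a_scales * A) (b_scales * B) with per-tensor or per-row/per-column scales and integer accumulation, can be written as a plain host-side loop. This is only an illustrative sketch; the shapes, row-major layouts, and function name are assumptions, not the kernel's actual interface:

// Reference semantics: D[m][n] = a_scale(m) * b_scale(n) * sum_k A[m][k] * B[k][n].
// a_scales has size 1 (per-tensor) or M (per-row); b_scales has size 1 or N (per-column).
#include <cstdint>
#include <vector>

std::vector<float> scaled_mm_ref(const std::vector<int8_t>& A,      // M x K, row-major (assumed)
                                 const std::vector<int8_t>& B,      // K x N, row-major (assumed)
                                 const std::vector<float>& a_scales,
                                 const std::vector<float>& b_scales,
                                 int M, int N, int K) {
  std::vector<float> D(static_cast<size_t>(M) * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;  // int32 accumulation, matching the int8 path
      for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(A[m * K + k]) * static_cast<int32_t>(B[k * N + n]);
      }
      const float sa = a_scales[a_scales.size() == 1 ? 0 : m];
      const float sb = b_scales[b_scales.size() == 1 ? 0 : n];
      D[static_cast<size_t>(m) * N + n] = sa * sb * static_cast<float>(acc);
    }
  }
  return D;
}
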
*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::ScaleA; - using ScaleB = typename SUPER::ScaleB; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - using ScaleAArgs = typename ScaleA::Arguments; - using ScaleBArgs = typename ScaleB::Arguments; - - ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; - ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; - - typename EVTCompute0::Arguments evt0_compute_args{b_args}; - - typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args}; - return evt_compute_args; - } -}; - -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::ScaleA; - using ScaleB = typename SUPER::ScaleB; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, ElementD, Stride, Int<1>, Int<0>>>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - using ScaleAArgs = typename ScaleA::Arguments; - using ScaleBArgs = typename ScaleB::Arguments; - using BiasArgs = typename Bias::Arguments; - - ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; - ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; - BiasArgs bias_args{static_cast(bias.data_ptr()), {}}; - - typename EVTCompute0::Arguments evt0_compute_args{b_args}; - - typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args, - bias_args}; - return evt_compute_args; - } -}; - -template typename ArchGuard, - typename ElementAB_, typename ElementD_, - template typename Epilogue_, typename TileShape, - typename WarpShape, typename InstructionShape, int32_t MainLoopStages> -struct cutlass_2x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using Operator = - typename std::conditional, - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type; - - using OutputTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - TileShape, WarpShape, float, 4, 1 /* epilogue stages */ - >; - - using Epilogue = Epilogue_; - using EVTCompute = typename Epilogue::EVTCompute; - - using D = 
cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, - Stride, Int<0>>>; - - using EVTD = cutlass::epilogue::threadblock::Sm80EVT; - - // clang-format off - using RowMajor = typename cutlass::layout::RowMajor; - using ColumnMajor = typename cutlass::layout::ColumnMajor; - using KernelType = - ArchGuard::GemmKernel>; - // clang-format on - - using Op = cutlass::gemm::device::GemmUniversalAdapter; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - cutlass::gemm::GemmCoord problem_size{m, n, k}; - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideC = Stride, Int<0>>; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); - - typename Gemm::D::Arguments d_args{c_ptr, c_stride}; - - using Epilogue = typename Gemm::Epilogue; - auto evt_args = - Epilogue::prepare_args(std::forward(epilogue_params)...); - - typename Gemm::EVTD::Arguments epilogue_args{ - evt_args, - d_args, - }; - - typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode - problem_size, // problem size - 1, // batch count - epilogue_args, - a_ptr, - b_ptr, - nullptr, - nullptr, - 0, - 0, - 0, - 0, - lda, - ldb, - ldc, - ldc}; - - // Launch the CUTLASS GEMM kernel. - typename Gemm::Op gemm_op; - size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get(), stream); - CUTLASS_CHECK(status); -} - -template -void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - // In some cases, the GPU isn't able to accommodate the - // shared memory requirements of the Gemm. In such cases, use - // the FallbackGemm instead. 
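
The comment above states the fallback criterion: run the tuned Gemm only when its shared-memory footprint fits the device's opt-in per-block limit, otherwise dispatch to FallbackGemm. A rough sketch of that check, assuming the helper amounts to querying the CUDA opt-in attribute (names mirror the surrounding code but are otherwise illustrative):

// Illustrative sketch of the shared-memory capacity check behind the fallback path.
// Gemm and FallbackGemm only need to expose KernelType::SharedStorage here.
#include <cuda_runtime.h>
#include <cstddef>

inline int max_shared_mem_per_block_opt_in(int device) {
  int value = 0;
  // Error handling omitted in this sketch.
  cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  return value;
}

template <typename Gemm, typename FallbackGemm>
bool primary_gemm_fits(int device) {
  // Prefer the tuned Gemm; if its SharedStorage exceeds what the device can
  // opt into, the caller should fall back to FallbackGemm instead.
  const size_t needed = sizeof(typename Gemm::KernelType::SharedStorage);
  const size_t limit = static_cast<size_t>(max_shared_mem_per_block_opt_in(device));
  return needed <= limit;
}
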
- static const int max_shared_mem_per_block_opt_in = - get_cuda_max_shared_memory_per_block_opt_in(0); - - size_t const gemm_shared_mem_size = - sizeof(typename Gemm::KernelType::SharedStorage); - size_t const fallback_gemm_shared_mem_size = - sizeof(typename FallbackGemm::KernelType::SharedStorage); - - if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) { - return cutlass_gemm_caller(out, a, b, - std::forward(args)...); - } else { - TORCH_CHECK(fallback_gemm_shared_mem_size <= - max_shared_mem_per_block_opt_in); - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template typename Epilogue> -struct sm80_config_default { - // This config is used in 2 cases, - // - M in (128, inf) - // - M in (64, 128] and N >= 8192 - // Shared Memory required by this Gemm - 81920 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M64 { - // This config is used in 2 cases, - // - M in (32, 64] - // - M in (64, 128] and N < 8192 - // Shared Memory required by this Gemm - 122880 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M32 { - // M in (16, 32] - // Shared Memory required by this Gemm - 61440 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M16 { - // M in [1, 16] - // Shared Memory required by this Gemm - 51200 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -} // namespace - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); - - using Cutlass2xGemmDefault = - typename sm80_config_default::Cutlass2xGemm; - using Cutlass2xGemmM128BigN = - typename sm80_config_default::Cutlass2xGemm; - using Cutlass2xGemmM128SmallN = - typename sm80_config_M64::Cutlass2xGemm; - using Cutlass2xGemmM64 = - typename sm80_config_M64::Cutlass2xGemm; - using Cutlass2xGemmM32 = - typename sm80_config_M32::Cutlass2xGemm; - using Cutlass2xGemmM16 = - typename sm80_config_M16::Cutlass2xGemm; - - // Due to shared memory requirements, some Gemms may fail to run on some - // GPUs. As the name indicates, the Fallback Gemm is used as an alternative - // in such cases. - // sm80_config_M16 has the least shared-memory requirement. 
However, - // based on some profiling, we select sm80_config_M32 as a better alternative - // performance wise. - using FallbackGemm = - typename sm80_config_M32::Cutlass2xGemm; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(16), next_pow_2(m)); // next power of 2 - if (mp2 <= 16) { - // M in [1, 16] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 32) { - // M in (16, 32] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 64) { - // M in (32, 64] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // M in (64, 128] - uint32_t const n = out.size(1); - bool const small_n = n < 8192; - if (small_n) { - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } - } else { - // M in (128, inf) - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} template