diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index f96c38bf57d..92a1bcada38 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -7,7 +7,7 @@ steps: commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -62,23 +62,45 @@ steps: env: DOCKER_BUILDKIT: "1" - - block: "Build release image" + - label: "Build release image (x86)" depends_on: ~ - key: block-release-image-build - - - label: "Build release image" - depends_on: block-release-image-build - id: build-release-image + id: build-release-image-x86 agents: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + # re-tag to default image tag and push, just in case arm64 build fails + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Build release image (arm64)" + depends_on: ~ + id: build-release-image-arm64 + agents: + queue: arm64_cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." 
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + + # Add job to create multi-arch manifest + - label: "Create multi-arch manifest" + depends_on: + - build-release-image-x86 + - build-release-image-arm64 + id: create-multi-arch-manifest + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "docker manifest create public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64 public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64 --amend" + - "docker manifest push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Annotate release workflow" depends_on: - - build-release-image + - create-multi-arch-manifest - build-wheel-cuda-12-8 - build-wheel-cuda-12-6 - build-wheel-cuda-11-8 diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index df0bae0c9cb..c395011a244 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -164,7 +164,6 @@ if [[ $commands == *" entrypoints/llm "* ]]; then --ignore=entrypoints/llm/test_chat.py \ --ignore=entrypoints/llm/test_accuracy.py \ --ignore=entrypoints/llm/test_init.py \ - --ignore=entrypoints/llm/test_generate_multiple_loras.py \ --ignore=entrypoints/llm/test_prompt_validation.py "} fi diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 9dec9f8e9eb..0f734763f13 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. 
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e @@ -49,23 +49,23 @@ function cpu_tests() { # Run kernel tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -v -s tests/kernels/test_onednn.py" + pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e # Note: disable until supports V1 - # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model - # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_cache.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model # Note: disable Bart until supports V1 - pytest -v -s tests/models/language/generation -m cpu_model \ + pytest -x -v -s tests/models/language/generation -m cpu_model \ --ignore=tests/models/language/generation/test_bart.py - VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ + VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model \ --ignore=tests/models/language/generation/test_bart.py - pytest -v -s tests/models/language/pooling -m cpu_model - pytest -v -s tests/models/multimodal/generation \ + pytest -x -v -s tests/models/language/pooling -m cpu_model + pytest -x -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ --ignore=tests/models/multimodal/generation/test_pixtral.py \ -m cpu_model" @@ -73,33 +73,49 @@ function cpu_tests() { # Run compressed-tensor test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" # Note: disable it until supports V1 # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e - # VLLM_USE_V1=0 pytest -s -v \ + # VLLM_USE_V1=0 pytest -x -s -v \ # tests/quantization/test_ipex_quant.py" # Run multi-lora tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/lora/test_qwen2vl.py" - # online serving + # online serving: tp+pp docker exec cpu-test-"$NUMA_NODE" bash -c ' set -e 
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & + server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ - --endpoint /v1/completions' + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' + + # online serving: tp+dp + docker exec cpu-test-"$NUMA_NODE" bash -c ' + set -e + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & + server_pid=$! + timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 + vllm bench serve \ + --backend vllm \ + --dataset-name random \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --num-prompts 20 \ + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' } # All of CPU tests are expected to be finished less than 40 mins. diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 445cd2735c1..73f3e63fbf5 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -31,6 +31,7 @@ docker run \ set -e echo $ZE_AFFINITY_MASK VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp cd tests diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0d3b7a294d9..482808cd07e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -109,10 +109,9 @@ steps: - tests/entrypoints/offline_mode commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Entrypoints Test (API Server) # 40min @@ -234,16 +233,33 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: V1 Test +- label: V1 Test e2e + engine mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 commands: - # split the test to avoid interference - - pytest -v -s v1/core + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. 
+ - pytest -v -s v1/e2e - pytest -v -s v1/engine + +- label: V1 Test entrypoints + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: - pytest -v -s v1/entrypoints + +- label: V1 Test others + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # split the test to avoid interference + - pytest -v -s v1/core - pytest -v -s v1/executor - pytest -v -s v1/sample - pytest -v -s v1/logits_processors @@ -256,9 +272,6 @@ steps: - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_metrics_reader.py - # TODO: accuracy does not match, whether setting - # VLLM_USE_FLASHINFER_SAMPLER or not on H100. - - pytest -v -s v1/e2e # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine @@ -312,7 +325,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py parallelism: 4 - label: PyTorch Compilation Unit Tests @@ -449,8 +462,8 @@ steps: - tests/quantization commands: # temporary install here since we need nightly, will move to requirements/test.in - # after torchao 0.12 release - - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + # after torchao 0.12 release, and pin a working version of torchao nightly here + - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min @@ -654,6 +667,7 @@ steps: # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -663,6 +677,7 @@ steps: - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py ##### 1 GPU test ##### ##### multi gpus test ##### @@ -791,13 +806,14 @@ steps: # requires multi-GPU testing for validation. 
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - - pytest -v -s -x lora/test_multi_loras_with_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py - label: Weight Loading Multiple GPU Test # 33min mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" - num_gpus: 2 + num_gpus: 2 + optional: true source_file_dependencies: - vllm/ - tests/weight_loading diff --git a/.github/scale-config.yml b/.github/scale-config.yml new file mode 100644 index 00000000000..c41a3ee3eb1 --- /dev/null +++ b/.github/scale-config.yml @@ -0,0 +1,21 @@ +# scale-config.yml: +# Powers what instance types are available for GHA auto-scaled +# runners. Runners listed here will be available as self hosted +# runners, configuration is directly pulled from the main branch. +# runner_types: +# runner_label: +# instance_type: m4.large +# os: linux +# # min_available defaults to the global cfg in the ALI Terraform +# min_available: undefined +# # when max_available value is not defined, no max runners is enforced +# max_available: undefined +# disk_size: 50 +# is_ephemeral: true + +runner_types: + linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: true + os: linux diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml index 6401d6586cc..e0ab3872d8f 100644 --- a/.github/workflows/issue_autolabel.yml +++ b/.github/workflows/issue_autolabel.yml @@ -49,6 +49,10 @@ jobs: term: "VLLM_ROCM_", searchIn: "both" }, + { + term: "aiter", + searchIn: "title" + }, { term: "rocm", searchIn: "title" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 612b290e88d..c16bdeeecd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: ruff-format files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos - rev: v1.34.0 + rev: v1.35.5 hooks: - id: typos - repo: https://github.com/PyCQA/isort diff --git a/CMakeLists.txt b/CMakeLists.txt index b0eb0f32e03..3f1f9a781a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,8 +45,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1") -set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0") # # Try to find python package with an executable that exactly matches @@ -541,6 +541,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -559,6 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") diff --git a/README.md b/README.md index ef5b4358895..8812aac4ea2 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 - [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c7f290e1eb8..6b24b8c8f3c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -96,7 +96,6 @@ def run_vllm( end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] # output_len should be the same for all requests. 
output_len = requests[0].expected_output_len for request in requests: diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py new file mode 100644 index 00000000000..9663503e9ba --- /dev/null +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_block_fp8_matmul, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton as vllm_triton + +assert current_platform.is_cuda(), ( + "Only support benchmarking w8a8 block fp8 kernel on CUDA device." +) + +# DeepSeek-V3 weight shapes +DEEPSEEK_V3_SHAPES = [ + (512 + 64, 7168), + (2112, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + (18432 * 2, 7168), + (24576, 1536), + (12288, 7168), + (4096, 7168), + (7168, 2048), +] + + +def build_w8a8_block_fp8_runner(M, N, K, block_size, device): + """Build runner function for w8a8 block fp8 matmul.""" + factor_for_scale = 1e-2 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # Create random FP8 tensors + A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + # Create scales + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) + * factor_for_scale + ) + + def run(): + return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16) + + return run + + +@vllm_triton.testing.perf_report( + vllm_triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=["torch-bf16", "w8a8-block-fp8"], + line_names=["torch-bf16", "w8a8-block-fp8"], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs W8A8 Block FP8 GEMMs", + args={}, + ) +) +def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): + M = batch_size + device = "cuda" + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + a = torch.randn((M, K), device=device, dtype=torch.bfloat16) + b = torch.randn((N, K), device=device, dtype=torch.bfloat16) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: # w8a8-block-fp8 + run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +if __name__ == "__main__": + block_size = (128, 128) + + for N, K in DEEPSEEK_V3_SHAPES: + print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}") + + print(f"TFLOP/s comparison (block_size={block_size}):") + benchmark_tflops.run( + print_data=True, + # show_plots=False, + # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}", + N=N, + K=K, + 
block_size=block_size, + ) + + print("\nBenchmark finished!") diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 752c2d00821..710d30adfd8 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -419,8 +419,10 @@ def benchmark( ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. + block_n = block_quant_shape[0] if block_quant_shape else None + block_k = block_quant_shape[1] if block_quant_shape else None op_config = get_moe_configs( - num_experts, shard_intermediate_size // 2, dtype_str + num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k ) if op_config is None: config = get_default_config( @@ -430,6 +432,7 @@ def benchmark( hidden_size, topk, dtype_str, + block_quant_shape, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index e648a91077f..98bde9d83c8 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -141,6 +141,7 @@ def get_weight_shapes(tp_size): # cannot TP total = [ (512 + 64, 7168), + (2112, 7168), ((128 + 64) * 128, 7168), (128 * (128 + 128), 512), (7168, 16384), diff --git a/csrc/cache.h b/csrc/cache.h index fb0c353b961..e8e069aefd9 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -36,6 +36,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, const std::string& kv_cache_dtype, torch::Tensor& scale); +void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, + torch::Tensor& cp_local_token_select_indices, + torch::Tensor& kv_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& scale); + // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); @@ -47,4 +54,12 @@ void gather_and_maybe_dequant_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, const std::string& kv_cache_dtype, torch::Tensor const& scale, - std::optional seq_starts = std::nullopt); \ No newline at end of file + std::optional seq_starts = std::nullopt); + +// TODO(hc): cp_gather_cache needs to support scaled KV cache in the future. +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b3a985c2d5b..fbb022464ef 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -395,6 +396,51 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } +template +__global__ void cp_fused_concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim] + const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = cp_local_token_select_indices[blockIdx.x]; + const int64_t slot_idx = slot_mapping[blockIdx.x]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + } // namespace vllm // KV_T is the data type of key and value tensors. @@ -508,6 +554,20 @@ void reshape_and_cache_flash( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::cp_fused_concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + cp_local_token_select_indices.data_ptr(), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -546,6 +606,50 @@ void concat_and_cache_mla( CALL_CONCAT_AND_CACHE_MLA); } +// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel +// calls into one: +// k_c_normed.index_select(0, cp_local_token_select_indices) + \ +// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \ +// concat_and_cache_mla. 
+void cp_fused_concat_and_cache_mla( + torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_total_tokens, pe_dim] + torch::Tensor& cp_local_token_select_indices, // [num_tokens] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CP_FUSED_CONCAT_AND_CACHE_MLA); +} + namespace vllm { template @@ -779,3 +883,145 @@ void gather_and_maybe_dequant_cache( DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); } + +namespace vllm { +template +// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by +// block_size. +__global__ void cp_gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRY_SIZE] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts // Optional: starting offsets per + // batch +) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch. 
+ // If seq_starts is provided, compute an offset based on it + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + if (seq_starts != nullptr) { + offset += seq_starts[bid]; + } + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * dst_entry_stride; + copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr); + offset += 1; + // bump to next block + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \ + vllm::cp_gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting slot index by +// seq_starts[bid] +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 
4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + + if (dtype_bits == 32) { + CALL_CP_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_CP_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_CP_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index f7b75c48373..2728aa81f0c 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -19,6 +19,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_CASE_HALF_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) + // ROCm devices might use either fn or fnuz, so set up dispatch table for both. // A host-based check at runtime will create a preferred FP8 type for ROCm // such that the correct kernel is dispatched. @@ -45,6 +52,15 @@ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) +#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \ + AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 99c52ef17d0..cd80bfda7df 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -573,7 +573,7 @@ void topk_softmax( stream); } else { - assert(topk_indices.scalar_type() == at::ScalarType::Int64); + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), diff --git a/csrc/ops.h b/csrc/ops.h index 86fe848e2fd..7a176a5c003 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -130,6 +130,14 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant(torch::Tensor& out, + torch::Tensor& output_block_scale, + torch::Tensor& input, + torch::Tensor& input_global_scale); +#endif + void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu new file mode 100644 index 00000000000..9bbeb0334fb --- /dev/null +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -0,0 +1,368 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" + +namespace vllm { + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = c10::Half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = c10::BFloat16; +}; + +template <> +struct TypeConverter { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). 
+ int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +template +__inline__ __device__ PackedVec compute_silu(PackedVec& vec, + PackedVec& vec2) { + PackedVec result; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + if constexpr (std::is_same_v) { + half2 val(0.5f, 0.5f); + half2 t0 = __hmul2(vec.elts[i], val); + half2 t1 = __hfma2(h2tanh(t0), val, val); + half2 t2 = __hmul2(vec.elts[i], t1); + result.elts[i] = __hmul2(t2, vec2.elts[i]); + } else { + __nv_bfloat162 val(0.5f, 0.5f); + __nv_bfloat162 t0 = __hmul2(vec.elts[i], val); + __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val); + __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1); + result.elts[i] = __hmul2(t2, vec2.elts[i]); + } + } + return result; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, + PackedVec& vec2, + float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + PackedVec out_silu = compute_silu(vec, vec2); + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(out_silu.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(out_silu.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. 
+ SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(out_silu.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4( +#else +silu_and_cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = + rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVec in_vec2 = reinterpret_cast(in)[inOffset2]; + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. 
+ int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + ; + auto& out_pos = out[outOffset]; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numCols, SFout); + + out_pos = silu_and_cvt_warp_fp16_to_fp4( + in_vec, in_vec2, SFScaleVal, sf_out); + } + } +#endif +} + +} // namespace vllm + +void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { + TORCH_CHECK(input.dtype() == torch::kFloat16 || + input.dtype() == torch::kBFloat16); + int32_t m = input.size(0); + int32_t n = input.size(1) / 2; + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "act_and_mul_quant_kernel", [&] { + auto input_ptr = reinterpret_cast(input.data_ptr()); + VLLM_DISPATCH_BYTE_TYPES( + output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type", + [&] { + vllm::silu_and_cvt_fp16_to_fp4 + <<>>( + m, n, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + }); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7ae054dc19f..56626a02c02 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -115,6 +115,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + ops.def( + "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " + "Tensor input, Tensor input_global_scale) -> ()"); + ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); +#endif + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -686,6 +694,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor scale) -> ()"); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); + cache_ops.def( + "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe," + " Tensor cp_local_token_select_indices," + " Tensor! kv_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " Tensor scale) -> ()"); + cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA, + &cp_fused_concat_and_cache_mla); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " @@ -702,6 +720,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor scale, Tensor? seq_starts) -> ()"); cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, &gather_and_maybe_dequant_cache); + + cache_ops.def( + "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " + "Tensor cu_seq_lens, int batch_size, Tensor? 
seq_starts) -> ()"); + cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 61ea44220ad..221a7bd9621 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -3,6 +3,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 058eba5fe0b..efda9c8e019 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index bb47e1b90f0..2d8cdcc11fa 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -164,15 +164,18 @@ llm = LLM( ) ``` -!! important +!!! important Batch-level DP is not to be confused with API request-level DP (which is instead controlled by `data_parallel_size`). -The availability of batch-level DP is based on model implementation. -Currently, the following models support `mm_encoder_tp_mode="data"`: +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. + +Known supported models: - Llama4 () -- MiniCPM-V-4 () +- MiniCPM-V-2.5 or above (, ) - Qwen2.5-VL () - Step3 () @@ -204,20 +207,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 to avoid CPU resource exhaustion. !!! 
note - [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled - because it requires a one-to-one correspondence between API and engine core processes. + API server scale-out disables [multi-modal IPC caching](#ipc-caching) + because it requires a one-to-one correspondence between API and engine core processes. -## Multi-Modal Caching + This does not impact [multi-modal processor caching](#processor-caching). -### Processor Cache +## Multi-Modal Caching -By default, the multi-modal processor cache is enabled to avoid repeatedly processing -the same multi-modal inputs via Hugging Face `AutoProcessor`, +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, which commonly occurs in multi-turn conversations. -You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` -(default 4 GiB per API process + 4 GiB per engine core process). -If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`. Examples: @@ -230,3 +246,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) ``` + +### Cache Placement + +Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows: + +| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | +|-------------------|-------------|------------|------------|-------------| +| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | +| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| ❌ | ❌ | N/A | N/A | `0` | + +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index ac2b6baffd1..e456077e049 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -45,30 +45,30 @@ This initial compilation time ranges significantly and is impacted by many of th ### Optimize based on your data -#### max model len vs. most model len +#### max-model-len vs. most-model-len ![most_model_len](../assets/design/tpu/most_model_len.png) -If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most model len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. +If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable.
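As a hedged sketch of this knob for offline inference (the model name and the 32k/2k split are illustrative assumptions), the variable is simply set before the engine is constructed:

```python
# Hedged sketch: allow occasional 32k-token requests while hinting that most
# requests fit within 2k tokens. Model name is illustrative.
import os

os.environ["VLLM_TPU_MOST_MODEL_LEN"] = "2048"  # read by the TPU backend at startup

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=32768)
```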
For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`. -The requests get subdivided into max-model-len and most-model-len categories, for the latter category, we can gain better performance since the server can process more requests at a time. +The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time. #### Padding -For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128: 128, 256, etc. +For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.) -The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about tpu padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: +The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: -1) the default exponential padding (pad to the nearest power of 2) -2) bucket padding (pad to the nearest linearly increasing bucket). +1. the default exponential padding (pad to the nearest power of 2) +2. bucket padding (pad to the nearest linearly increasing bucket). When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`. For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]. -The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. +The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320. However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 7ef22d6f8c3..3dae62dd5d9 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -90,7 +90,7 @@ address the long build time at its source, the current workaround is to set `VLL to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) when manually triggering a build on Buildkite. This branch accomplishes two things: -1. Increase the timeout limit to 10 hours so that the build doesn't timeout. +1. 
Increase the timeout limit to 10 hours so that the build doesn't time out. 2. Allow the compiled artifacts to be written to the vLLM sccache S3 bucket to warm it up so that future builds are faster. diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index 21b1f21d60a..aafdb1058e0 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -121,3 +121,31 @@ To support a model with interleaving sliding windows, we need to take care of th - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171). With these two steps, interleave sliding windows should work with the model. + +### How to support models that use Mamba? + +We consider 3 different scenarios: + +1. Models that use Mamba layers (either Mamba-1 or Mamba-2) but do not use attention layers. +2. Models that combine Mamba layers (either Mamba-1 or Mamba-2) together with attention layers. +3. Models that combine Mamba-like mechanisms (e.g., Linear Attention, ShortConv) together with attention layers. + +For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](gh-file:vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](gh-file:vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference. +The model should inherit protocol `IsAttentionFree` and also implement class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config. +For the mamba layers themselves, please use the [`MambaMixer`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes. +Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations. +V0-only classes and code will be removed in the very near future. +The model should also be added to the `MODELS_CONFIG_MAP` dictionary in to ensure that the runtime defaults are optimized. + +For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](gh-file:vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](gh-file:vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together). +These models should follow the same instructions as case (1), but they should inherit protocol `IsHybrid` (instead of `IsAttentionFree`) and it is *not* necessary to add them to the `MODELS_CONFIG_MAP` (their runtime defaults will be inferred from the protocol). + +For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](gh-file:vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](gh-file:vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively. +Please follow the same guidelines as case (2) for implementing these models. +We use "mamba-like" to refer to layers that possess a state that is updated in-place, rather than being appended-to (like KV cache for attention).
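To make the class-level hooks above more concrete, here is a schematic sketch for case (1); the exact signatures and the protocol inheritance are assumptions, so treat the referenced `mamba.py`/`mamba2.py` implementations as the source of truth:

```python
# Schematic sketch only (case 1): an attention-free Mamba-style model exposing
# the class-method hooks described above. Signatures are assumptions, not the
# exact vLLM interfaces.
import torch
from torch import nn


class MyMamba2ForCausalLM(nn.Module):  # would also inherit vLLM's IsAttentionFree protocol
    @classmethod
    def get_mamba_state_dtype_from_config(cls, vllm_config) -> tuple[torch.dtype, ...]:
        # Return the dtypes of the per-layer Mamba states (e.g. conv and SSM state).
        ...

    @classmethod
    def get_mamba_state_shape_from_config(cls, vllm_config) -> tuple[tuple[int, ...], ...]:
        # Return the per-layer state shapes derived from the model config
        # (state size, conv kernel width, number of heads, etc.).
        ...
```

The layers inside such a model would then be built from `MambaMixer` or `MambaMixer2` as noted above.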
+For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`. +It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers. +Please see [`LinearAttentionMetadata`](gh-file:vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](gh-file:v1/attention/backends/short_conv_attn.py) for examples of this. +Finally, if one wants to support torch compile and CUDA graphs, it is necessary to wrap the call to the mamba-like layer inside a custom op and register it. +Please see the calls to `direct_register_custom_op` in or for examples of this. +The new custom op should then be added to the list `_attention_ops` in to ensure that piecewise CUDA graphs work as intended. diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 76d0f067fd4..dc742c8fcf2 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -855,7 +855,7 @@ Examples: ### Custom HF processor -Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. +Some models don't define an HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. Examples: diff --git a/docs/deployment/frameworks/lobe-chat.md b/docs/deployment/frameworks/lobe-chat.md index e3e7dbe6e1e..8ecd1484eab 100644 --- a/docs/deployment/frameworks/lobe-chat.md +++ b/docs/deployment/frameworks/lobe-chat.md @@ -6,6 +6,6 @@ Supports speech-synthesis, multi-modal, and extensible (function call) plugin sy One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. -It supports vLLM as a AI model provider to efficiently serve large language models. +It supports vLLM as an AI model provider to efficiently serve large language models. For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index cad801a4312..ca23e0b9fd8 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -380,7 +380,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ### Startup Probe or Readiness Probe Failure, container log contains "KeyboardInterrupt: terminated" -If the startup or readiness probe failureThreshold is too low for the time needed to startup the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: +If the startup or readiness probe failureThreshold is too low for the time needed to start up the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: 1. container log contains "KeyboardInterrupt: terminated" 2.
`kubectl get events` shows message `Container $NAME failed startup probe, will be restarted` diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 202e9c1caf1..b03483d1c9b 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -138,7 +138,7 @@ Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & C #### Step 1: Add an All2All manager -The purpose of the All2All Manager is to setup the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). +The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). #### Step 2: Add a FusedMoEPrepareAndFinalize Type diff --git a/docs/design/metrics.md b/docs/design/metrics.md index b24364247b3..90b2fd32f29 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -99,11 +99,11 @@ http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201 ### Multi-process Mode -In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See . +In v0, metrics are collected in the engine core process and we use multiprocess mode to make them available in the API server process. See . ### Built in Python/Process Metrics -The following metrics are supported by default by `prometheus_client`, but they are not exposed when multi-process mode is used: +The following metrics are supported by default by `prometheus_client`, but they are not exposed when multiprocess mode is used: - `python_gc_objects_collected_total` - `python_gc_objects_uncollectable_total` diff --git a/docs/features/lora.md b/docs/features/lora.md index 668460a368a..db794b2ebd7 100644 --- a/docs/features/lora.md +++ b/docs/features/lora.md @@ -52,7 +52,7 @@ Check out for an exa ## Serving LoRA Adapters LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use -`--lora-modules {name}={path} {name}={path}` to specify each LoRA module when we kickoff the server: +`--lora-modules {name}={path} {name}={path}` to specify each LoRA module when we kick off the server: ```bash vllm serve meta-llama/Llama-2-7b-hf \ diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 9d51f9cf52f..206ab7a4687 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -13,6 +13,41 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: - `prompt`: The prompt should follow the format that is documented on HuggingFace. - `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. +### Stable UUIDs for Caching (multi_modal_uuids) + +When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. 
You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content. + +??? code + + ```python + from vllm import LLM + from PIL import Image + + # Qwen2.5-VL example with two images + llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct") + + prompt = "USER: \nDescribe the differences.\nASSISTANT:" + img_a = Image.open("/path/to/a.jpg") + img_b = Image.open("/path/to/b.jpg") + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": {"image": [img_a, img_b]}, + # Provide stable IDs for caching. + # Requirements (matched by this example): + # - Include every modality present in multi_modal_data. + # - For lists, provide the same number of entries. + # - Use None to fall back to content hashing for that item. + "multi_modal_uuids": {"image": ["sku-1234-a", None]}, + }) + + for o in outputs: + print(o.outputs[0].text) + ``` + +!!! warning + If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored. + ### Image Inputs You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples: diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 04b943efbbb..d9a785eb73f 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -143,7 +143,7 @@ OpenAI Python client library does not officially support `reasoning_content` att print(content, end="", flush=True) ``` -Remember to check whether the `reasoning_content` exists in the response before accessing it. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). +Remember to check whether the `reasoning_content` exists in the response before accessing it. You could check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). ## Tool Calling diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 8a934d406f3..0d6294a5fdd 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -205,7 +205,7 @@ This section covers the OpenAI beta wrapper over the `client.chat.completions.cr At the time of writing (`openai==1.54.4`), this is a "beta" feature in the OpenAI client library. Code reference can be found [here](https://github.com/openai/openai-python/blob/52357cff50bee57ef442e94d78a0de38b4173fc2/src/openai/resources/beta/chat/completions.py#L100-L104). 
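Concretely, the wrapper accepts a Pydantic class as `response_format` and returns a parsed object; the snippet below is a condensed, hedged sketch (base URL, API key, and schema are illustrative assumptions) of what the fuller example that follows demonstrates:

```python
# Hedged sketch: structured output via the OpenAI client's beta parse helper,
# pointed at a local vLLM server (URL, API key, and schema are assumptions).
from openai import OpenAI
from pydantic import BaseModel


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: str


client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.beta.chat.completions.parse(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Describe the most iconic car of the 90s."}],
    response_format=CarDescription,
)
print(completion.choices[0].message.parsed)
```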
-For the following examples, vLLM was setup using `vllm serve meta-llama/Llama-3.1-8B-Instruct` +For the following examples, vLLM was set up using `vllm serve meta-llama/Llama-3.1-8B-Instruct` Here is a simple example demonstrating how to get structured output using Pydantic models: diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index 0ee680f5c68..8a658b7a910 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -12,7 +12,6 @@ vLLM supports the following hardware platforms: - [Apple silicon](cpu.md#apple-silicon) - [IBM Z (S390X)](cpu.md#ibm-z-s390x) - [Google TPU](google_tpu.md) -- [Intel Gaudi](intel_gaudi.md) - [AWS Neuron](aws_neuron.md) ## Hardware Plugins diff --git a/docs/getting_started/installation/aws_neuron.md b/docs/getting_started/installation/aws_neuron.md index b8bd76bd5bc..ff2500f0352 100644 --- a/docs/getting_started/installation/aws_neuron.md +++ b/docs/getting_started/installation/aws_neuron.md @@ -140,8 +140,8 @@ Alternatively, users can directly call the NxDI library to trace and compile you - `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the - artifacts under `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set, - but the directory does not exist, or the contents are invalid, Neuron will also fallback to a new compilation and store the artifacts + artifacts under `neuron-compiled-artifacts/{unique_hash}/` subdirectory in the model path. If this environment variable is set, + but the directory does not exist, or the contents are invalid, Neuron will also fall back to a new compilation and store the artifacts under this specified path. - `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). - `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend). diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index e76ec35e1ed..7f0ecb2bc0b 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels. - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. 
The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. +- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). @@ -179,7 +180,7 @@ Inference batch size is an important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP together if there are enough CPU sockets and memory nodes. +vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use DP, TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 2828173a76a..124a41adf1a 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -1,6 +1,6 @@ # --8<-- [start:installation] -vLLM has experimental support for macOS with Apple silicon. For now, users must build from source to natively run on macOS. +vLLM has experimental support for macOS with Apple Silicon. For now, users must build from source to natively run on macOS. Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. 
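Putting the CPU knobs above together, a hedged offline-inference sketch (model name, KV cache budget, and parallel degree are illustrative) could look like this:

```python
# Hedged sketch: CPU offline inference with an explicit KV cache budget and
# automatic OpenMP thread binding. Values and model name are illustrative.
import os

os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"       # 40 GiB for the KV cache
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "auto"  # bind threads per NUMA node

from vllm import LLM

# tensor_parallel_size=2 assumes two CPU sockets / NUMA nodes are available.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```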
diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 6dc6f94249c..f7af259ace6 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -43,7 +43,7 @@ docker build -f docker/Dockerfile.cpu \ # Launching OpenAI server docker run --rm \ - --privileged=true \ + --security-opt seccomp=unconfined \ --shm-size=4g \ -p 8000:8000 \ -e VLLM_CPU_KVCACHE_SPACE= \ diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 69a9842e471..275232e12e0 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -48,7 +48,7 @@ uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VE #### Install the latest code -LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on a x86 platform with CUDA 12 for every commit since `v0.5.3`. +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on an x86 platform with CUDA 12 for every commit since `v0.5.3`. ```bash uv pip install -U vllm \ diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 560883d3caf..80e99d3034d 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -149,7 +149,7 @@ Build a docker image from which setup ROCm **This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** If you choose to build this rocm_base image yourself, the steps are as follows. -It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```json { @@ -170,7 +170,7 @@ DOCKER_BUILDKIT=1 docker build \ #### Build an image with vLLM First, build a docker image from and launch a docker container from the image. -It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. 
Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```bash { diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md deleted file mode 100644 index ff912efec9c..00000000000 --- a/docs/getting_started/installation/intel_gaudi.md +++ /dev/null @@ -1,388 +0,0 @@ -# Intel Gaudi - -This page provides instructions on running vLLM with Intel Gaudi devices. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - -## Requirements - -- OS: Ubuntu 22.04 LTS -- Python: 3.10 -- Intel Gaudi accelerator -- Intel Gaudi software version 1.18.0 - -Please follow the instructions provided in the -[Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) -to set up the execution environment. To achieve the best performance, -please follow the methods outlined in the -[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). - -## Configure a new environment - -### Environment verification - -To verify that the Intel Gaudi software was correctly installed, run: - -```bash -hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible -apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed -pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed -pip list | grep neural # verify that neural_compressor_pt is installed -``` - -Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) -for more details. - -### Run Docker Image - -It is highly recommended to use the latest Docker image from Intel Gaudi -vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) -for more details. - -Use the following commands to run a Docker image: - -```bash -docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --ipc=host \ - vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -``` - -## Set up using Python - -### Pre-built wheels - -Currently, there are no pre-built Intel Gaudi wheels. - -### Build wheel from source - -To build and install vLLM from source, run: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -r requirements/hpu.txt -python setup.py develop -``` - -Currently, the latest features and performance optimizations are developed in Gaudi's [vLLM-fork](https://github.com/HabanaAI/vllm-fork) and we periodically upstream them to vLLM main repo. 
To install latest [HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the following: - -```bash -git clone https://github.com/HabanaAI/vllm-fork.git -cd vllm-fork -git checkout habana_main -pip install -r requirements/hpu.txt -python setup.py develop -``` - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Intel Gaudi images. - -### Build image from source - -```bash -docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --rm vllm-hpu-env -``` - -!!! tip - If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. - -## Extra information - -### Supported features - -- [Offline inference](../../serving/offline_inference.md) -- Online serving via [OpenAI-Compatible Server](../../serving/openai_compatible_server.md) -- HPU autodetection - no need to manually select device within vLLM -- Paged KV cache with algorithms enabled for Intel Gaudi accelerators -- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, - prefill attention, Root Mean Square Layer Normalization, Rotary - Positional Encoding -- Tensor parallelism support for multi-card inference -- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) - for accelerating low-batch latency and throughput -- Attention with Linear Biases (ALiBi) -- INC quantization - -### Unsupported features - -- Beam search -- LoRA adapters -- AWQ quantization -- Prefill chunking (mixed-batch inferencing) - -### Supported configurations - -The following configurations have been validated to function with -Gaudi2 devices. Configurations that are not listed may or may not work. 
- -| Model | TP Size| dtype | Sampling | -|-------|--------|--------|----------| -| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy | - -## Performance tuning - -### Execution modes - -Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment variable), and `--enforce-eager` flag. - -| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | -|----------------------|-------------------|--------------------| -| 0 | 0 | torch.compile | -| 0 | 1 | PyTorch eager mode | -| 1 | 0 | HPU Graphs | - -!!! warning - In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. - -[](){ #gaudi-bucketing-mechanism } - -### Bucketing mechanism - -Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. -In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - `batch_size` and `sequence_length`. - -!!! 
note - Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. - -Bucketing ranges are determined with 3 parameters - `min`, `step` and `max`. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: - -```text -INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] -INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] -INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] -INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -``` - -| Parameter | Description | -|----------------|-----------------------------------------------------------------------------| -| `min` | Determines the lowest value of the bucket. | -| `step` | Determines the interval between buckets. | -| `max` | Determines the upper bound of the bucket. | -| Ramp-up phase | A special handling phase applied between `min` and `step`:
- `min` is multiplied by consecutive powers of two until `step` is reached.
- Minimizes resource wastage for small batch sizes.
- Allows larger padding for larger batches. | - -Example (with ramp-up): - -```text -min = 2, step = 32, max = 64 -=> ramp_up = (2, 4, 8, 16) -=> stable = (32, 64) -=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) -``` - -Example (without ramp-up): - -```text -min = 128, step = 128, max = 512 -=> ramp_up = () -=> stable = (128, 256, 384, 512) -=> buckets = ramp_up + stable => (128, 256, 384, 512) -``` - -In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. - -!!! warning - If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. - -As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as `(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as `(4, 512)` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a `(2, 512)` bucket, or context length increases above 512 tokens, in which case it will become `(4, 640)` bucket. - -!!! note - Bucketing is transparent to a client -- padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. - -### Warmup - -Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: - -??? console "Logs" - - ```text - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB - INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB - ... - INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB - INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB - ... 
- INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - ``` - -This example uses the same buckets as in the [Bucketing Mechanism][gaudi-bucketing-mechanism] section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. - -!!! tip - Compiling all the buckets might take some time and can be turned off with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. - -### HPU Graph capture - -[HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. - -When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by `gpu_memory_utilization` flag (`0.9` by default). -Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. -Only after that, `gpu_memory_utilization` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. -Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. -Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of memory reserved for HPU Graphs capture. -With its default value (`VLLM_GRAPH_RESERVED_MEM=0.1`), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache. -Environment variable `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (`VLLM_GRAPH_PROMPT_RATIO=0.3`), both stages have equal memory constraints. -Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. - -!!! note - `gpu_memory_utilization` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, `gpu_memory_utilization` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. - -User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. 
There are two strategies implemented: - -- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode -- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt - -When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. - -!!! note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt to do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. - -Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): - -??? 
console "Logs" - - ```text - INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache - INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 - INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB - ... - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - ... - INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB - INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB - ... 
- INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB - INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory - INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) - ``` - -### Recommended vLLM Parameters - -- We recommend running inference on Gaudi 2 with `block_size` of 128 - for BF16 data type. Using default values (16, 32) might lead to - sub-optimal performance due to Matrix Multiplication Engine - under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). -- For max throughput on Llama 7B, we recommend running with batch size - of 128 or 256 and max context length of 2048 with HPU Graphs enabled. - If you encounter out-of-memory issues, see troubleshooting section. - -### Environment variables - -**Diagnostic and profiling knobs:** - -- `VLLM_PROFILER_ENABLED`: If `true`, enable the high level profiler. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: If `true`, log graph compilations for each vLLM engine step when any occurs. Highly recommended to use with `PT_HPU_METRICS_GC_DETAILS=1`. `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: If `true`, always log graph compilations for each vLLM engine step even if none occurred. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: If `true`, log CPU fallbacks for each vLLM engine step when any occurs. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, always log CPU fallbacks for each vLLM engine step even if none occurred. `false` by default. 
- -**Performance tuning knobs:** - -- `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by default - -- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for HPUGraph capture, `0.1` by default - -- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.3` by default - -- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`, `min_tokens` by default - -- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`, `max_bs` by default - -- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - - - `{phase}` is either `PROMPT` or `DECODE` - - - `{dim}` is either `BS`, `SEQ` or `BLOCK` - - - `{param}` is either `MIN`, `STEP` or `MAX` - - - Default values: - -| `{phase}` | Parameter | Env Variable | Value Expression | -|-----------|-----------|--------------|------------------| -| Prompt | Batch size min | `VLLM_PROMPT_BS_BUCKET_MIN` | `1` | -| Prompt | Batch size step | `VLLM_PROMPT_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Prompt | Batch size max | `VLLM_PROMPT_BS_BUCKET_MAX` | `min(max_num_seqs, 64)` | -| Prompt | Sequence length min | `VLLM_PROMPT_SEQ_BUCKET_MIN` | `block_size` | -| Prompt | Sequence length step | `VLLM_PROMPT_SEQ_BUCKET_STEP` | `block_size` | -| Prompt | Sequence length max | `VLLM_PROMPT_SEQ_BUCKET_MAX` | `max_model_len` | -| Decode | Batch size min | `VLLM_DECODE_BS_BUCKET_MIN` | `1` | -| Decode | Batch size step | `VLLM_DECODE_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Decode | Batch size max | `VLLM_DECODE_BS_BUCKET_MAX` | `max_num_seqs` | -| Decode | Sequence length min | `VLLM_DECODE_BLOCK_BUCKET_MIN` | `block_size` | -| Decode | Sequence length step | `VLLM_DECODE_BLOCK_BUCKET_STEP` | `block_size` | -| Decode | Sequence length max | `VLLM_DECODE_BLOCK_BUCKET_MAX` | `max(128, (max_num_seqs*max_model_len)/block_size)` | - -Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: - -- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used; if `1`, PyTorch Lazy backend for Gaudi will be used. `1` is default. -- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs - -## Troubleshooting: tweaking HPU graphs - -If you experience device out-of-memory issues or want to attempt -inference at higher batch sizes, try tweaking HPU Graphs by following -the below: - -- Tweak `gpu_memory_utilization` knob. It will decrease the - allocation of KV cache, leaving some headroom for capturing graphs - with larger batch size. By default `gpu_memory_utilization` is set - to 0.9. It attempts to allocate ~90% of HBM left for KV cache after - short profiling run. Note that decreasing reduces the number of KV - cache blocks you have available, and therefore reduces the effective - maximum number of tokens you can handle at a given time. -- If this method is not efficient, you can disable `HPUGraph` - completely. With HPU Graphs disabled, you are trading latency and - throughput at lower batches for potentially higher throughput on - higher batches. You can do that by adding `--enforce-eager` flag to - server (for online serving), or by passing `enforce_eager=True` - argument to LLM constructor (for offline inference). 
diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index fbb5f6f6dd1..d2fbb1870dd 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -258,4 +258,4 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -A openai client example can be found here: +An OpenAI client example can be found here: diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 74f3a9d1cdb..e8fe77e8d6c 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -40,7 +40,7 @@ If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it mean #### Custom models -If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! +If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM! For a model to be compatible with the Transformers backend for vLLM it must: @@ -335,9 +335,9 @@ th { | `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | | `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | @@ -358,7 +358,7 @@ th { | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. 
| | ✅︎ | ✅︎ | | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | -| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | | ✅︎ | +| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | @@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' ``` +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture. + !!! note Load the official original `mxbai-rerank-v2` by using the following command. @@ -616,6 +620,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | | `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. 
| | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | @@ -628,6 +633,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | @@ -637,7 +643,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. 
| | | | diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index e78c67522f6..4c7a7ff019e 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -51,7 +51,7 @@ tail ~/.config/vllm/usage_stats.json ## Opting out -You can opt-out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: +You can opt out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: ```bash # Any of the following methods can disable usage stats collection diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 64bd0d9bf50..f71805436a6 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -107,14 +107,14 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`, `FalconMambaForCausalLM`) are supported. -Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that -these models currently require disabling prefix caching in V1. +Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). -Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`). -Please note that these models currently require disabling prefix caching and enforcing eager mode in V1. +Hybrid models with mechanisms different from Mamba are also supported (e.g., `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`). + +Please note that prefix caching is not yet supported for any of the above models. #### Encoder-Decoder Models diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor.py index 7ef20efa7d2..3e122319169 100644 --- a/examples/offline_inference/logits_processor.py +++ b/examples/offline_inference/logits_processor.py @@ -42,8 +42,8 @@ class object. from vllm.v1.sample.logits_processor import ( BatchUpdate, LogitsProcessor, - MoveDirectionality, ) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates # Hypothetical custom logits processor @@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor): def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): - self.req_info: dict[int, SamplingParams] = {} + self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - # Process added requests.
- for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and ( - target_token := params.extra_args.get("target_token") - ): - self.req_info[index] = target_token - - if self.req_info: - # Process removed requests. - for index in batch_update.removed: - self.req_info.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_info.pop(adx, None) - b_val = self.req_info.pop(bdx, None) - if a_val is not None: - self.req_info[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_info[adx] = b_val + process_dict_updates( + self.req_info, + batch_update, + # This function returns the LP's per-request state based on the + # request details, or None if this LP does not apply to the + # request. + lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), + ) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index c4972f02d0f..5af232cb6af 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -138,7 +138,7 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - TokensPrompt(prompt_token_ids=prompt_ids), + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], sampling_params=sampling_params, ) else: diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 8d97ba26682..4e879666f61 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: " + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Florence2 def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1602,6 +1633,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "chameleon": run_chameleon, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 584db53db4e..f238c66234d 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent): token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] 
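As a quick illustration of how a subscriber might consume these event payloads now that the optional `medium` field is present, a hedged sketch follows; the handler function is hypothetical and assumes the event classes defined above in this example.

```python
# Hedged sketch, not part of the example script: one way a subscriber could
# dispatch on the event types above. Assumes the KVCacheEvent, BlockStored and
# BlockRemoved classes defined earlier in this file; the function itself is a
# hypothetical helper.
def handle_event(event: KVCacheEvent) -> None:
    if isinstance(event, BlockStored):
        # `medium` is Optional; fall back to a generic label when it is unset.
        print(f"stored {len(event.token_ids)} tokens "
              f"(block_size={event.block_size}, medium={event.medium or 'default'})")
    elif isinstance(event, BlockRemoved):
        print(f"removed {len(event.block_hashes)} blocks "
              f"(medium={event.medium or 'default'})")
    else:
        print(f"other event: {type(event).__name__}")
```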
class AllBlocksCleared(KVCacheEvent): diff --git a/pyproject.toml b/pyproject.toml index 013f2a6cd59..e63f8aeae27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.7.1", + "torch == 2.8.0", "wheel", "jinja2", ] diff --git a/requirements/build.txt b/requirements/build.txt index dd644d621ef..5f826a1afa1 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,8 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.7.1 +torch==2.8.0 wheel jinja2>=3.1.6 regex +build diff --git a/requirements/cpu.txt b/requirements/cpu.txt index f4b95b72898..a48cb9fde00 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -9,17 +9,16 @@ packaging>=24.2 setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 -torch==2.7.0; platform_system == "Darwin" -torch==2.7.0; platform_machine == "ppc64le" -torch==2.6.0; platform_machine == "aarch64" # for arm64 CPUs, torch 2.7.0 has a issue: https://github.com/vllm-project/vllm/issues/17960 +torch==2.8.0; platform_system == "Darwin" +torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" -torchaudio==2.7.0; platform_machine == "ppc64le" +torchaudio==2.8.0; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" -torchvision==0.22.0; platform_machine == "ppc64le" +torchvision==0.23.0; platform_machine == "ppc64le" datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs diff --git a/requirements/cuda.txt b/requirements/cuda.txt index fb30e493f80..3f8b8fca320 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -6,9 +6,9 @@ numba == 0.61.2; python_version > '3.9' # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.7.1 -torchaudio==2.7.1 +torch==2.8.0 +torchaudio==2.8.0 # These must be updated alongside torch -torchvision==0.22.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -# https://github.com/facebookresearch/xformers/releases/tag/v0.0.31 -xformers==0.0.31; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 \ No newline at end of file +torchvision==0.23.0 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 +xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index cbae9bbb8a9..affe562c24f 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,10 +1,10 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.2.4 -torch==2.7.0 -torchvision==0.22.0 -torchaudio==2.7.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch==2.8.0 +torchvision==0.23.0 +torchaudio==2.8.0 triton==3.3.0 cmake>=3.26.1,<4 diff --git a/requirements/test.in b/requirements/test.in index 098a9242bc3..5b1688c76c9 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -22,9 +22,9 @@ sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.7.1 -torchaudio==2.7.1 -torchvision==0.22.1 +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[image,audio] >= 1.8.2 # required for voxtral test @@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 terratorch==1.1rc2 # required for PrithviMAE test +decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index 8b872752d87..0b728ebfb00 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -156,6 +156,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -493,6 +495,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate @@ -538,42 +541,42 @@ numpy==1.26.4 # tritonclient # vocos # xarray -nvidia-cublas-cu12==12.8.3.14 +nvidia-cublas-cu12==12.8.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.8.57 +nvidia-cuda-cupti-cu12==12.8.90 # via torch -nvidia-cuda-nvrtc-cu12==12.8.61 +nvidia-cuda-nvrtc-cu12==12.8.93 # via torch -nvidia-cuda-runtime-cu12==12.8.57 +nvidia-cuda-runtime-cu12==12.8.90 # via torch -nvidia-cudnn-cu12==9.7.1.26 +nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.3.41 +nvidia-cufft-cu12==11.3.3.83 # via torch -nvidia-cufile-cu12==1.13.0.11 +nvidia-cufile-cu12==1.13.1.3 # via torch -nvidia-curand-cu12==10.3.9.55 +nvidia-curand-cu12==10.3.9.90 # via torch -nvidia-cusolver-cu12==11.7.2.55 +nvidia-cusolver-cu12==11.7.3.90 # via torch -nvidia-cusparse-cu12==12.5.7.53 +nvidia-cusparse-cu12==12.5.8.93 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.3 +nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.26.2 +nvidia-nccl-cu12==2.27.3 # via torch -nvidia-nvjitlink-cu12==12.8.61 +nvidia-nvjitlink-cu12==12.8.93 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.8.55 +nvidia-nvtx-cu12==12.8.90 # via torch omegaconf==2.3.0 # via @@ -1066,7 +1069,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.7.1+cu128 +torch==2.8.0+cu128 # via # -r requirements/test.in # accelerate @@ -1095,7 +1098,7 @@ torch==2.7.1+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.7.1+cu128 
+torchaudio==2.8.0+cu128 # via # -r requirements/test.in # encodec @@ -1108,7 +1111,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.22.1+cu128 +torchvision==0.23.0+cu128 # via # -r requirements/test.in # lightly @@ -1149,7 +1152,7 @@ transformers==4.55.2 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.3.1 +triton==3.4.0 # via torch tritonclient==2.51.0 # via diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 0e1059e6544..fcc2589e421 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -4,32 +4,41 @@ import torch import vllm.envs as envs -from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +# yapf conflicts with isort for this block +# yapf: disable +from vllm.compilation.activation_quant_fusion import ( + FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass) +# yapf: enable +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, kFp8StaticTensorSym, kNvfp4Quant) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args, - **kwargs): - super().__init__(*args, **kwargs) +def is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +class TestSiluMulFp8QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs): + super().__init__() self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) - self.w = (torch.rand( - hidden_size, - hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = Fp8LinearOp( force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, @@ -45,14 +54,56 @@ def forward(self, x): input_scale=self.wscale) return x2 + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]] + + def ops_in_model_after(self): + return [FUSED_OPS[kFp8StaticTensorSym]] + + +class TestSiluMulNvfp4QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.w = torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE) + self.wscale = torch.randn(hidden_size, + hidden_size // 16).to(dtype=FP8_DTYPE) + self.wscale2 = torch.rand(1, dtype=torch.float32) + self.scale = torch.rand(1, dtype=torch.float32) -@pytest.mark.parametrize("num_tokens", [256]) -@pytest.mark.parametrize("hidden_size", [64]) + def forward(self, x): + y = self.silu_and_mul(x) + y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale) + out = cutlass_scaled_fp4_mm(a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.wscale, + alpha=self.scale * self.wscale2, + 
out_dtype=y.dtype) + return out + + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]] + + def ops_in_model_after(self): + return [FUSED_OPS[kNvfp4Quant]] + + +@pytest.mark.parametrize("num_tokens", [64]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize( + "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]) @pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, force_fp8_e4m3fnuz): + if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz: + pytest.skip("Duplicate tests for NVFP4") + torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -63,7 +114,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(NoOpEliminationPass(config), fusion_pass) - model = TestModel(hidden_size, force_fp8_e4m3fnuz) + model = model_class(hidden_size=hidden_size, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz) # First dimension dynamic x = torch.rand(num_tokens, hidden_size * 2) @@ -80,17 +132,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, atol=1e-3, rtol=1e-3) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes - - silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default - - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None - find_auto_fn(pre_nodes, fp8_quant) + # In pre-nodes, quant op should be present and fused kernels should not + backend.check_before_ops(model.ops_in_model_before()) - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, silu_and_mul_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + # In post-nodes, fused kernels should be present and quant op should not + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/conftest.py b/tests/conftest.py index 2bf88abb0f6..9fed43cba54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import math import os import tempfile from enum import Enum -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast import numpy as np import pytest @@ -33,6 +34,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams +from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect logger = init_logger(__name__) @@ -454,11 +456,10 @@ def classify(self, prompts: list[str]) -> list[str]: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] + problem_type = getattr(self.config, "problem_type", "") + for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - - problem_type = getattr(self.config, "problem_type", "") - if problem_type == "regression": logits = output.logits[0].tolist() elif problem_type == 
"multi_label_classification": @@ -602,7 +603,7 @@ def _hidden_states_to_seq_logprobs( def _hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: int, + num_logprobs: Optional[int], ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -630,7 +631,7 @@ def generate_greedy_logprobs_limit( self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, @@ -677,7 +678,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: @@ -966,7 +967,7 @@ def generate_greedy_logprobs( self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, @@ -991,11 +992,40 @@ def generate_greedy_logprobs( videos=videos, **kwargs) + def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: + """ + Return the perplexity score associated with generating the prompts + + :param prompts: list of prompts to score + :return: perplexity score of each prompt + """ + outputs = self.generate_greedy_logprobs(prompts, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0) + + perplexities = [] + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs)) + perplexities.append(perplexity) + + return perplexities + def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, ) -> Union[list[TokensTextLogprobs], @@ -1022,15 +1052,17 @@ def generate_beam_search( images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - outputs = self.llm.beam_search( - inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + outputs = self.llm.beam_search(inputs, + BeamSearchParams(beam_width=beam_width, + max_tokens=max_tokens), + concurrency_limit=concurrency_limit) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 28150d76823..1afe9ea970c 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -118,6 +118,8 @@ def fast( multi_node_only: bool = False, load_format: Optional[str] = 
None, ): + vllm_major_versions = ["1"] if runner == "pooling" else ["0"] + return PPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, @@ -126,7 +128,7 @@ def fast( chunked_prefill=False), ], distributed_backends=["mp"], - vllm_major_versions=["0"], + vllm_major_versions=vllm_major_versions, runner=runner, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -213,7 +215,9 @@ def iter_params(self, model_id: str): EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"), - "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), + # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883 + # is fixed + #"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast( load_format="dummy", runner="pooling" ), diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 49b8eddecb4..c93b436f384 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -292,7 +292,7 @@ def _compare_sp( # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", ] diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index 57705ff6690..6c0c9cd0158 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -16,14 +16,6 @@ prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -70,3 +62,9 @@ def test_encode_api(llm: LLM): err_msg = "pooling_task must be one of.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompts, use_tqdm=False) + + +def test_score_api(llm: LLM): + err_msg = "Score API is only enabled for num_labels == 1." 
+ with pytest.raises(ValueError, match=err_msg): + llm.score("ping", "pong", use_tqdm=False) diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index cb54b16b0b0..eae3e234378 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -27,14 +27,6 @@ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py deleted file mode 100644 index a04f195692e..00000000000 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import weakref - -import pytest -# downloading lora to test lora requests -from huggingface_hub import snapshot_download - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" - -PROMPTS = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -LORA_NAME = "typeof/zephyr-7b-beta-lora" - - -@pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module", params=[False, True]) -def llm(request, monkeypatch_module): - - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') - - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - tensor_parallel_size=1, - max_model_len=8192, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - max_num_seqs=128, - enforce_eager=True) - - yield weakref.proxy(llm) - - del llm - - cleanup_dist_env_and_memory() - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.mark.skip_global_cleanup -def test_multiple_lora_requests(llm: LLM, zephyr_lora_files): - lora_request = [ - LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files) - for idx in range(len(PROMPTS)) - ] - # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(PROMPTS, lora_request=lora_request) - assert len(PROMPTS) == len(outputs) - - # Exception raised, if the size of params does not match the size of prompts - with pytest.raises(ValueError): - outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) - - # Single LoRARequest should be applied to every prompt - single_lora_request = lora_request[0] - outputs = llm.generate(PROMPTS, lora_request=single_lora_request) - assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index de82cf8d403..2cee3c8d94e 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -16,14 +16,6 @@ prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a 
package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index 5a1339b2add..f715dacacb8 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -14,14 +14,6 @@ MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index dd8d63ad319..a154bb1059a 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -32,15 +32,16 @@ "tensor_parallel_size": 1, "tokenizer_mode": "mistral", }, - { - "model": "sentence-transformers/all-MiniLM-L12-v2", - "enforce_eager": True, - "gpu_memory_utilization": 0.20, - "max_model_len": 64, - "max_num_batched_tokens": 64, - "max_num_seqs": 64, - "tensor_parallel_size": 1, - }, + # TODO: re-enable once these tests are run with V1 + # { + # "model": "sentence-transformers/all-MiniLM-L12-v2", + # "enforce_eager": True, + # "gpu_memory_utilization": 0.20, + # "max_model_len": 64, + # "max_num_batched_tokens": 64, + # "max_num_seqs": 64, + # "tensor_parallel_size": 1, + # }, ] diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 58195f98bd3..0d0ce0be8c5 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -49,8 +49,7 @@ async def transcribe_audio(client, tokenizer, y, sr): return latency, num_output_tokens, transcription.text -async def bound_transcribe(model_name, sem, client, audio, reference): - tokenizer = AutoTokenizer.from_pretrained(model_name) +async def bound_transcribe(sem, client, tokenizer, audio, reference): # Use semaphore to limit concurrent requests. async with sem: result = await transcribe_audio(client, tokenizer, *audio) @@ -63,15 +62,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference): async def process_dataset(model, client, data, concurrent_request): sem = asyncio.Semaphore(concurrent_request) + # Load tokenizer once outside the loop + tokenizer = AutoTokenizer.from_pretrained(model) + # Warmup call as the first `librosa.load` server-side is quite slow. 
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] - _ = await bound_transcribe(model, sem, client, (audio, sr), "") + _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "") tasks: list[asyncio.Task] = [] for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + bound_transcribe(sem, client, tokenizer, (audio, sr), + sample["text"])) tasks.append(task) return await asyncio.gather(*tasks) diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 30078fe9025..36c96d76c2e 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -226,3 +226,33 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_score(server: RemoteOpenAIServer, model_name: str): + # score api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("score"), + json={ + "model": model_name, + "text_1": "ping", + "text_2": "pong", + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_rerank(server: RemoteOpenAIServer, model_name: str): + # rerank api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": "ping", + "documents": ["pong"], + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index b20838956d7..9a1c0ea13b5 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -27,6 +27,28 @@ def serve_parser(): return make_arg_parser(parser) +### Test config parsing +def test_config_arg_parsing(serve_parser, cli_config_file): + args = serve_parser.parse_args([]) + assert args.port == 8000 + args = serve_parser.parse_args(['--config', cli_config_file]) + assert args.port == 12312 + args = serve_parser.parse_args([ + '--config', + cli_config_file, + '--port', + '9000', + ]) + assert args.port == 9000 + args = serve_parser.parse_args([ + '--port', + '9000', + '--config', + cli_config_file, + ]) + assert args.port == 9000 + + ### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index cf2442a5693..d46ab304ba6 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -24,14 +24,6 @@ DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = [ diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index f4801172580..818efd82564 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -47,6 +47,7 @@ class MockModelConfig: allowed_local_media_path: str = "" encoder_config 
= None generation_config: str = "auto" + skip_tokenizer_init: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 73364294cbc..ce4d6c5f5d3 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -14,14 +14,6 @@ DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index cb6ec795ae9..4fafcfb45fa 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -12,15 +12,6 @@ from ...utils import RemoteOpenAIServer - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - MODELS = [ { "name": "BAAI/bge-reranker-v2-m3", diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py new file mode 100644 index 00000000000..ed003939c44 --- /dev/null +++ b/tests/entrypoints/openai/test_token_in_token_out.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import tempfile + +import pytest + +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf) +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b") + + +@pytest.fixture(scope="module") +def server(): + global MODEL_PATH + MODEL_PATH = download_weights_from_hf( + MODEL_NAME, + allow_patterns=["*"], + cache_dir=MODEL_PATH, + ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"]) + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + "--skip-tokenizer-init", + "--load-format", + "dummy", + ] + with RemoteOpenAIServer(MODEL_PATH, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_token_in_token_out_and_logprobs(server): + """ + Test token-in-token-out and token_ids align with prompt_logprobs + & logprobs when return_tokens_as_token_ids is enabled. + """ + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + text = "Hello, world! How are you today?" 
+ token_ids = tokenizer.encode(text) + async with server.get_async_client() as client: + # Test with both return_token_ids and return_tokens_as_token_ids enabled + completion = await client.completions.create( + model=MODEL_PATH, + prompt=token_ids, + max_tokens=20, + temperature=0, + echo=True, + extra_body={ + "return_token_ids": True, + }, + ) + + # Verify all fields are present + assert (completion.choices[0].token_ids is not None + and 0 < len(completion.choices[0].token_ids) <= 20) + assert completion.choices[0].prompt_token_ids is not None + + # Decode prompt tokens + if completion.choices[0].prompt_token_ids: + prompt_text = tokenizer.decode( + completion.choices[0].prompt_token_ids) + # The decoded prompt should match or close to original prompt + assert prompt_text == text diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index cbf11da63ca..69e96dfd2cb 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -790,6 +790,78 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, torch.testing.assert_close(dst, expected) +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_blocks", [1024]) +@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", + ["auto"]) # You can also test "fp8" if needed. +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, + num_blocks, max_seq_len, batch_size, dtype, + kv_cache_dtype, device): + entry_size = kv_lora_rank + qk_rope_head_dim + src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, + kv_cache_dtype, device) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(0, + max_seq_len + 1, (batch_size, ), + device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty((batch_size, num_blocks), + dtype=torch.int32, + device=device) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), + dtype=src_cache.dtype, + device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + if s == 0: + continue + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + gathered_rows.append(src_cache[blocks[i]]) + remaining = s - (tot - 1) * block_size + gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.cp_gather_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + torch.testing.assert_close(dst, expected) + + 
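Since the new `cp_gather_cache` test leans on the `cu_seq_lens` layout (an exclusive prefix sum over per-sequence lengths that doubles as packed-output offsets), a tiny standalone illustration follows; the values are arbitrary and not taken from the test.

```python
# Standalone illustration (arbitrary values): cu_seq_lens[i] is the offset in
# the packed dst tensor where sequence i starts, and the final entry equals the
# total number of gathered tokens.
import torch

seq_len_tensor = torch.tensor([3, 0, 5], dtype=torch.int32)
cu_seq_lens = torch.zeros(seq_len_tensor.numel() + 1, dtype=torch.int32)
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(torch.int32)
print(cu_seq_lens.tolist())  # [0, 3, 3, 8]
```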
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 9e4eaf221f2..ecc57acc679 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -16,7 +16,7 @@ fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 1e922be47f2..36a98522a65 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,8 +20,7 @@ FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -374,7 +373,7 @@ def _test_deepep_deepgemm_moe( @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): @@ -432,7 +431,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py new file mode 100644 index 00000000000..969f14cc3fe --- /dev/null +++ b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] +SEEDS = [42] +CUDA_DEVICES = ['cuda:0'] + +FLOAT4_E2M1_MAX 
= scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +BLOCK_SIZE = 16 + + +def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, + global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + silu_and_mul_out = silu_and_mul.forward_native(x) + assert not current_platform.is_rocm() + assert silu_and_mul_out.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.') + other_dims = 1 if silu_and_mul_out.ndim == 1 else -1 + silu_and_mul_out = silu_and_mul_out.reshape(other_dims, + silu_and_mul_out.shape[-1]) + m, n = silu_and_mul_out.shape + device = silu_and_mul_out.device + + # Two fp4 values will be packed into an uint8. + out = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + output_scale = ref_output_scale + + torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale, + global_scale) + + return out, output_scale + + +def ops_impl(x: torch.Tensor, global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + out_shape = (x.shape[0], x.shape[1] // 4) + output_scale = ref_output_scale + out = torch.empty(out_shape, dtype=torch.uint8, device=x.device) + torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale) + return out, output_scale + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_quantize_to_fp4( + dtype: torch.dtype, + shape: tuple[int, int], + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + + m, n = shape + + x = torch.randn((m, n), dtype=dtype) + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + block_size = 16 + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert x.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.') + + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(x.shape[0], 128) + scale_n = x.shape[1] // (2 * block_size) + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=x.device, + dtype=torch.int32) + + layer = SiluAndMul() + + ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale) + + fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale) + + assert ref_out.dtype == torch.uint8 + assert fusion_out.dtype == torch.uint8 + assert ref_out.shape == fusion_out.shape + + assert ref_out_scale.dtype == torch.int32 + assert fusion_out_scale.dtype == torch.int32 + assert ref_out_scale.shape == fusion_out_scale.shape + + # Allow up to 2% of mismatched values since BF16 has accuracy issues. 
+ mis_threshold = 0.02 + atol = 0.4 + rtol = 0.4 + ref_logits = ref_out[-1] + fusion_logits = fusion_out[-1] + + mis_count = torch.sum( + torch.abs(fusion_logits - ref_logits) > (atol + + rtol * torch.abs(ref_logits))) + mis_ratio = mis_count / fusion_logits.numel() + + assert mis_ratio < mis_threshold, \ + f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}" + + torch.testing.assert_close(ref_out_scale, fusion_out_scale) + + opcheck(torch.ops._C.silu_and_mul_nvfp4_quant, + (fusion_out, fusion_out_scale, x, global_scale)) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index cba573b63c0..3475993ff8f 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -216,11 +216,6 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") -@pytest.fixture(scope="session") -def phi2_lora_files(): - return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py deleted file mode 100644 index 774ebb9db21..00000000000 --- a/tests/lora/test_baichuan.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "baichuan-inc/Baichuan-7B" - -PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format(query="How many singers do we have?"), - PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 - ), - ] - print(prompts) - sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. 
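For reference, the 2% mismatch tolerance asserted in test_quantize_to_fp4 above can be read as a small helper; this restatement is an illustrative sketch only and is not code from the patch.

import torch

def mismatch_ratio(ref: torch.Tensor, test: torch.Tensor,
                   atol: float = 0.4, rtol: float = 0.4) -> float:
    # Fraction of elements whose absolute difference exceeds atol + rtol * |ref|,
    # computed in float32 so it also behaves sensibly on the packed uint8 outputs.
    diff = torch.abs(test.float() - ref.float())
    tol = atol + rtol * torch.abs(ref.float())
    return (diff > tol).float().mean().item()

# Mirrors the assertion in test_quantize_to_fp4:
#   assert mismatch_ratio(ref_logits, fusion_logits) < 0.02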
- generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_baichuan_lora(baichuan_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True) - - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age ASC", - ] - - output1 = do_sample(llm, baichuan_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] - output2 = do_sample(llm, baichuan_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] - - -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, - num_gpus_available, fully_sharded): - if num_gpus_available < 4: - pytest.skip(f"Not enough GPUs for tensor parallelism {4}") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index fb00e7b65b0..5cffb8cfcc2 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -87,6 +87,9 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): + # https://github.com/NVIDIA/nccl/issues/1790, set a lower value for + # gpu_memory_utilization here because NCCL >= 2.26.3 seems to use + # more GPU memory causing vLLM to OOM llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, @@ -95,7 +98,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + gpu_memory_utilization=0.85) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023babc..6e2dda464d8 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -243,7 +243,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 
8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -347,7 +347,7 @@ def create_random_embedding_layer(): @torch.inference_mode() # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -486,7 +486,7 @@ def create_random_embedding_layer(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) @@ -620,12 +620,15 @@ def _pretest(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: +def test_linear_replicated( + dist_init, + num_loras, + device, + stage, +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,10 +637,11 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + ) def create_random_linear_replicated_layer(): @@ -651,10 +655,6 @@ def create_random_linear_replicated_layer(): lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -734,14 +734,13 @@ def create_random_linear_replicated_layer(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,11 +749,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = 
LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_random_linear_parallel_layer(): if orientation == "row": @@ -777,10 +777,7 @@ def create_random_linear_parallel_layer(): lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -860,14 +857,13 @@ def create_random_linear_parallel_layer(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -876,11 +872,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_column_parallel_packed_layer(): if repeats == 2: @@ -924,10 +921,7 @@ class FakeConfig: model_config=FakeConfig()) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): diff --git a/tests/lora/test_multi_loras_with_tp.py b/tests/lora/test_llm_with_multi_loras.py similarity index 80% rename from tests/lora/test_multi_loras_with_tp.py rename to tests/lora/test_llm_with_multi_loras.py index fe9bd3f2695..3d8dd512a20 100644 --- a/tests/lora/test_multi_loras_with_tp.py +++ b/tests/lora/test_llm_with_multi_loras.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Script to test multi loras service with tp >= 2 +This script contains: +1. test multi loras service with tp >= 2 +2. 
test multi loras request """ +import pytest + from tests.utils import multi_gpu_test from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest @@ -156,3 +160,34 @@ def check_outputs(outputs: str, expected: str): output_text = call_llm_get_outputs(prompt, "Alice") check_outputs(output_text, expected_output) + + +def test_multiple_lora_requests(): + llm = LLM( + model=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_lora_rank=LORA_RANK, + max_model_len=512, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + PROMPTS = ["Hello, my name is"] * 2 + LORA_NAME = "Alice" + lora_request = [ + LoRARequest(LORA_NAME + str(idx), idx + 1, + LORA_NAME_PATH_MAP[LORA_NAME]) + for idx in range(len(PROMPTS)) + ] + # A list of LoRA requests is matched one-to-one with the prompts + outputs = llm.generate(PROMPTS, lora_request=lora_request) + assert len(PROMPTS) == len(outputs) + + # A ValueError is raised if the number of LoRA requests does not match the number of prompts + with pytest.raises(ValueError): + outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) + + # Single LoRARequest should be applied to every prompt + single_lora_request = lora_request[0] + outputs = llm.generate(PROMPTS, lora_request=single_lora_request) + assert len(PROMPTS) == len(outputs) diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py deleted file mode 100644 index 3090941e636..00000000000 --- a/tests/lora/test_phi.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import vllm -from vllm.lora.request import LoRARequest - -MODEL_PATH = "microsoft/phi-2" - -PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), - PROMPT_TEMPLATE.format( - sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 - context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 - ), - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="### End") - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_phi2_lora(phi2_lora_files): - # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, - # Otherwise, the lora-test will fail due to CUDA OOM.
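The new test_multiple_lora_requests above pins down how LLM.generate pairs LoRA requests with prompts: a list is matched one-to-one by position, a list of the wrong length raises ValueError, and a single request is broadcast to every prompt. A minimal usage sketch of that behaviour follows; the base model name and adapter path are placeholders rather than values taken from this patch.

from vllm import LLM
from vllm.lora.request import LoRARequest

# Placeholder model and adapter path; substitute real ones to actually run this.
llm = LLM(model="your-org/your-base-model", enable_lora=True, max_loras=4)
prompts = ["Hello, my name is", "The capital of France is"]

# One LoRARequest per prompt, matched by position.
per_prompt = [
    LoRARequest(f"adapter-{i}", i + 1, "/path/to/lora/adapter")
    for i in range(len(prompts))
]
outputs = llm.generate(prompts, lora_request=per_prompt)

# A single LoRARequest is applied to every prompt.
outputs = llm.generate(prompts, lora_request=per_prompt[0])

# A list whose length differs from len(prompts) raises ValueError.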
- llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True, - enable_chunked_prefill=True) - - expected_lora_output = [ - "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 - "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 - ] - - output1 = do_sample(llm, phi2_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i].startswith(expected_lora_output[i]) - output2 = do_sample(llm, phi2_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i].startswith(expected_lora_output[i]) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 57382914bfe..4c4434c9414 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -92,7 +92,8 @@ pytest.param( "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], - ) + ), + pytest.param("swiss-ai/Apertus-8B"), # apertus ]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 7e7cc893ec8..31ca3a6f0f9 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -100,21 +100,19 @@ def test_models( else: hf_outputs = None - if model not in V0_UNSUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v0_outputs = None + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None if model in V1_SUPPORTED_MODELS: - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) else: vllm_v1_outputs = None @@ -137,7 +135,7 @@ def test_models( ) -@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_batching( @@ -147,10 +145,6 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: - if model in V0_UNSUPPORTED_MODELS: - pytest.skip( - f"Unsupported V0 Engine. 
Skipping `test_batching` on {model}.") - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -188,29 +182,32 @@ def test_chunked_prefill( max_tokens: int, num_logprobs: int, chunked_prefill_token_size: int, + monkeypatch, ) -> None: max_num_seqs = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size - with vllm_runner(model, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, num_logprobs) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, - enable_chunked_prefill=False, - max_num_seqs=max_num_seqs) as vllm_model: - non_chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, + enable_chunked_prefill=False, + max_num_seqs=max_num_seqs) as vllm_model: + non_chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) + check_logprobs_close( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -281,25 +278,29 @@ def test_models_preemption_recompute( example_prompts, model: str, max_tokens: int, + monkeypatch, ) -> None: """ Tests that outputs are identical with and w/o preemptions (recompute). 
""" - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - scheduler = vllm_model.llm.llm_engine.scheduler[0] - scheduler.ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - scheduler.ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + scheduler = vllm_model.llm.llm_engine.scheduler[0] + scheduler.ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + scheduler.ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -402,24 +403,18 @@ def test_full_cuda_graph( else: hf_outputs = None - if model not in V0_UNSUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v0_outputs = None - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - compilation_config={'full_cuda_graph': True}, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + m.setenv("VLLM_USE_V1", "0") + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None + + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( @@ -466,24 +461,20 @@ def test_fp32_state( else: hf_outputs = None - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32") as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + m.setenv("VLLM_USE_V1", "0") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32", - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + if hf_outputs is not None: check_logprobs_close( 
outputs_0_lst=hf_outputs, diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 61c5fcab4f8..a74ad2aa259 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 4a1f8a53d02..640858125bf 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 206524d7caa..f473e0ba01f 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -13,7 +13,14 @@ RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + architecture="GemmaForSequenceClassification", + hf_overrides={ + "architectures": + ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": + "no_post_processing", + }), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." 
# noqa: E501 @@ -119,22 +126,9 @@ def predict( @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: - monkeypatch.setenv("VLLM_USE_V1", "0") - - assert model_info.architecture == "GemmaForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": "no_post_processing", - } - } +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(GemmaRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs, vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 2dd35c41515..f918b2b91bc 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -10,14 +10,6 @@ from ...utils import check_embeddings_close, check_transformers_version -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.mark.parametrize( "model", [ @@ -32,21 +24,15 @@ def v1(run_with_both_engines): "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window marks=[pytest.mark.core_model]), - # the qwen models interfere with each other (see PR - # https://github.com/vllm-project/vllm/pull/18720). - # To avoid this problem, for now we skip v0 since it will be - # deprecated anyway. pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", - marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), + marks=[pytest.mark.cpu_model]), # [Encoder-only] pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), + pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) def test_models( diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index f805a64103c..9911620c018 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import pytest @@ -33,12 +32,15 @@ ########### NewModel CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), ########### Qwen2ForCausalLM LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", @@ -60,11 +62,16 @@ ] RERANK_MODELS = [ - # classifier_pooling: mean CLSPoolingRerankModelInfo( + # classifier_pooling: mean "Alibaba-NLP/gte-reranker-modernbert-base", 
architecture="ModernBertForSequenceClassification", enable_test=True), + CLSPoolingRerankModelInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + architecture="GteNewForSequenceClassification", + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + enable_test=True), ] @@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - mteb_test_embed_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) @@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) + example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index 480bd5e4567..73823deeff4 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -10,12 +10,20 @@ from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +mxbai_rerank_hf_overrides = { + "architectures": ["Qwen2ForSequenceClassification"], + "classifier_from_token": ["0", "1"], + "method": "from_2_way_softmax", +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=False) ] @@ -71,13 +79,4 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "Qwen2ForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["Qwen2ForSequenceClassification"], - "classifier_from_token": ["0", "1"], - "method": "from_2_way_softmax", - } - - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 37f5566a330..5dd2d9eae91 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -11,12 +11,20 @@ from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +qwen3_reranker_hf_overrides = { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", 
architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=False) ] @@ -74,18 +82,7 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - } - } - - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", RERANK_MODELS) @@ -96,16 +93,8 @@ def test_rerank_models_mteb_tp(vllm_runner, assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, "tensor_parallel_size": 2, } - mteb_test_rerank_models(Qwen3RerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - atol=1.2e-2) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, + vllm_extra_kwargs) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index beafa0aed98..08722ac98b7 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -13,14 +13,6 @@ from ...utils import check_transformers_version -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture def math_step_prompts(): # ruff: noqa: E501 diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index 6b5ff706814..ef9d5530cde 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -23,15 +23,6 @@ "The capital of Germany is Berlin.", ] - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - DTYPE = "half" diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 96208f8eda6..d61b182761e 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -189,23 +189,21 @@ }, marks=[pytest.mark.core_model], ), - # FIXME(Isotr0py): Enable this test after - # https://github.com/huggingface/transformers/pull/39470 released - # "idefics3-transformers": VLMTestInfo( - # models=["HuggingFaceTB/SmolVLM-256M-Instruct"], - # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - # prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 - # img_idx_to_prompt=lambda idx: "", - # max_model_len=8192, - # max_num_seqs=2, - # auto_cls=AutoModelForImageTextToText, - # hf_output_post_proc=model_utils.idefics3_trunc_hf_output, - # 
image_size_factors=[(0.25, 0.5, 1.0)], - # vllm_runner_kwargs={ - # "model_impl": "transformers", - # }, - # marks=[pytest.mark.core_model], - # ), + "idefics3-transformers": VLMTestInfo( + models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[pytest.mark.core_model], + ), # Pixel values from processor are not 4D or 5D arrays "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], @@ -222,21 +220,6 @@ }, marks=[large_gpu_mark(min_gb=32)], ), - # Check "auto" with fallback to transformers - "internvl-transformers": VLMTestInfo( - models=["OpenGVLab/InternVL3-1B-hf"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "", - max_model_len=4096, - use_tokenizer_eos=True, - image_size_factors=[(0.25, 0.5, 1.0)], - vllm_runner_kwargs={ - "model_impl": "auto", - }, - auto_cls=AutoModelForImageTextToText, - marks=[pytest.mark.core_model], - ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -337,10 +320,6 @@ vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], - # FIXME(Isotr0py): This model is broken in Transformers v4.54.1, we - # should enable this again after the fix is released: - # https://github.com/huggingface/transformers/pull/39915 - marks=[pytest.mark.skip("HF model is broken")], ), "gemma3": VLMTestInfo( models=["google/gemma-3-4b-it"], @@ -461,6 +440,20 @@ use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "intern_vl-hf": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO, + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + video_idx_to_prompt=lambda idx: "