[LLM APIs] Fast follow ups for 2.44 (1/N) (#51042)
Some of these changes came from the bug bash and dogfooding:

- [x] Rename `VLLMService` to `VLLMServer`
- [x] Remove the extra namespace hierarchy from the most common import paths
to make things more consistent with `data.llm`
- [x] Fix some other inconsistencies
- [x] Use vLLM everywhere (instead of VLLM). Most of these changes
are on the Serve side.
- [x] In `ray.data.llm`, use `model_source` in the vllm_config (`model_id` is a
Serve-only concept from `ray.serve.llm`, since it refers to the model name exposed
to the model discovery layer); see the sketch after this list.
- [ ] Support vLLM v1 [in follow-up]
- [ ] Allow passing a single deployment to the LLM router (instead of forcing people to pass
a list with one item) [in follow-up]
- [x] Update the Ray Serve docs structure to be flatter, based on the
dogfooding feedback.
The Serve LLM docs now look flatter and more consistent with the Serve docs.
Serving LLMs is a single page in the sidebar, and the APIs are a
sub-header under the Ray Serve APIs page.
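
For orientation, here is a minimal sketch of what the flattened imports and the `model_source` rename look like after this change (the model name and batch settings are illustrative, borrowed from the doc examples in this diff):

```python
# Sketch only: mirrors the import paths and field names introduced in this PR.
from ray.data.llm import vLLMEngineProcessorConfig
from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter

# ray.data.llm now takes `model_source` (the checkpoint to load); `model_id`
# remains a Serve-only concept used by the model discovery layer.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    concurrency=1,
    batch_size=64,
)
```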

Overview page
<img width="1446" alt="image"
src="https://github.com/user-attachments/assets/db850f35-3a3d-46e1-9892-b3cd17681b98"
/>

API page
<img width="1433" alt="image"
src="https://github.com/user-attachments/assets/03a669d1-12c2-46c1-9616-484ea95f7082"
/>

---------

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
kouroshHakha authored Mar 4, 2025
1 parent c6b06e1 commit e151dc2
Showing 39 changed files with 482 additions and 478 deletions.
12 changes: 6 additions & 6 deletions doc/source/data/working-with-llms.rst
@@ -40,7 +40,7 @@ Upon execution, the Processor object instantiates replicas of the vLLM engine (u
import numpy as np

config = vLLMEngineProcessorConfig(
model="unsloth/Llama-3.1-8B-Instruct",
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096,
@@ -82,7 +82,7 @@ Some models may require a Hugging Face token to be specified. You can specify th
.. testcode::

config = vLLMEngineProcessorConfig(
model="unsloth/Llama-3.1-8B-Instruct",
model_source="unsloth/Llama-3.1-8B-Instruct",
runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}},
concurrency=1,
batch_size=64,
@@ -100,7 +100,7 @@ Use the `vLLMEngineProcessorConfig` to configure the vLLM engine.
from ray.data.llm import vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
model="unsloth/Llama-3.1-8B-Instruct",
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={"max_model_len": 20000},
concurrency=1,
batch_size=64,
@@ -111,7 +111,7 @@ For handling larger models, specify model parallelism.
.. testcode::

config = vLLMEngineProcessorConfig(
model="unsloth/Llama-3.1-8B-Instruct",
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"max_model_len": 16384,
"tensor_parallel_size": 2,
@@ -132,7 +132,7 @@ To optimize model loading, you can configure the `load_format` to `runai_streame
.. testcode::

config = vLLMEngineProcessorConfig(
model="unsloth/Llama-3.1-8B-Instruct",
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={"load_format": "runai_streamer"},
concurrency=1,
batch_size=64,
@@ -143,7 +143,7 @@ To do multi-LoRA batch inference, you need to set LoRA related parameters in `en
.. testcode::

config = vLLMEngineProcessorConfig(
model="unsloth/Llama-3.1-8B-Instruct",
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
enable_lora=True,
max_lora_rank=32,
@@ -42,7 +42,7 @@
"# 2. construct a vLLM processor config.\n",
"processor_config = vLLMEngineProcessorConfig(\n",
" # The base model.\n",
" model=\"unsloth/Llama-3.2-1B-Instruct\",\n",
" model_source=\"unsloth/Llama-3.2-1B-Instruct\",\n",
" # vLLM engine config.\n",
" engine_kwargs=dict(\n",
" # Specify the guided decoding library to use. The default is \"xgrammar\".\n",
68 changes: 68 additions & 0 deletions doc/source/serve/api/index.md
@@ -396,3 +396,71 @@ Content-Type: application/json
metrics.Gauge
schema.LoggingConfig
```

(serve-llm-api)=

## LLM API

```{eval-rst}
.. currentmodule:: ray
```


### Builders

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: doc/
serve.llm.build_vllm_deployment
serve.llm.build_openai_app
```

### Configs

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: doc/
:template: autosummary/autopydantic.rst
serve.llm.LLMConfig
serve.llm.LLMServingArgs
serve.llm.ModelLoadingConfig
serve.llm.CloudMirrorConfig
serve.llm.LoraConfig
```


### Deployments

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: doc/
serve.llm.VLLMServer
serve.llm.LLMRouter
```

### OpenAI API Models

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: doc/
:template: autosummary/autopydantic_show_json.rst
serve.llm.openai_api_models.ChatCompletionRequest
serve.llm.openai_api_models.CompletionRequest
serve.llm.openai_api_models.ChatCompletionStreamResponse
serve.llm.openai_api_models.ChatCompletionResponse
serve.llm.openai_api_models.CompletionStreamResponse
serve.llm.openai_api_models.CompletionResponse
serve.llm.openai_api_models.ErrorResponse
```
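
As a quick orientation for the builders and configs listed above, here is a hedged sketch of how they compose (the model names and autoscaling values are placeholders borrowed from examples elsewhere in this diff, not defaults):

```python
from ray import serve
from ray.serve.llm import LLMConfig, LLMServingArgs, build_openai_app

# Illustrative config: `model_id` is the name exposed to clients through the
# model discovery layer; `model_source` is the checkpoint that vLLM loads.
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(min_replicas=1, max_replicas=2),
    ),
)

# build_openai_app wires VLLMServer deployments behind an OpenAI-compatible LLMRouter.
app = build_openai_app(LLMServingArgs(llm_configs=[llm_config]))
serve.run(app)
```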
2 changes: 1 addition & 1 deletion doc/source/serve/doc_code/vllm_example.py
@@ -16,7 +16,7 @@
class VLLMPredictDeployment:
def __init__(self, **kwargs):
"""
Construct a VLLM deployment.
Construct a vLLM deployment.
Refer to https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
for the full list of arguments.
2 changes: 1 addition & 1 deletion doc/source/serve/index.md
@@ -13,7 +13,7 @@ multi-app
model-multiplexing
configure-serve-deployment
http-guide
Serving LLMs <llm/index>
Serving LLMs <llm/serving-llms>
Production Guide <production-guide/index>
monitoring
resource-allocation
65 changes: 0 additions & 65 deletions doc/source/serve/llm/api.rst

This file was deleted.

11 changes: 0 additions & 11 deletions doc/source/serve/llm/index.rst

This file was deleted.

@@ -1,5 +1,7 @@
Overview
========
.. _serving_llms:

Serving LLMs
============

Ray Serve LLM APIs allow users to deploy multiple LLM models together with a familiar Ray Serve API, while providing compatibility with the OpenAI API.

@@ -27,10 +29,10 @@ Key Components

The ``ray.serve.llm`` module provides two key deployment types for serving LLMs:

VLLMService
VLLMServer
~~~~~~~~~~~~~~~~~~

The VLLMService sets up and manages the vLLM engine for model serving. It can be used standalone or combined with your own custom Ray Serve deployments.
The VLLMServer sets up and manages the vLLM engine for model serving. It can be used standalone or combined with your own custom Ray Serve deployments.

LLMRouter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -65,8 +67,7 @@ Deployment through ``LLMRouter``
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.deployments import VLLMService, LLMRouter
from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
llm_config = LLMConfig(
model_loading_config=dict(
@@ -87,7 +88,7 @@
)
# Deploy the application
deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
deployment = VLLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
@@ -135,8 +136,7 @@ For deploying multiple models, you can pass a list of ``LLMConfig`` objects to t
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.deployments import VLLMService, LLMRouter
from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
llm_config1 = LLMConfig(
model_loading_config=dict(
@@ -165,8 +165,8 @@ For deploying multiple models, you can pass a list of ``LLMConfig`` objects to t
)
# Deploy the application
deployment1 = VLLMService.as_deployment(llm_config1.get_serve_options(name_prefix="VLLM:")).bind(llm_config1)
deployment2 = VLLMService.as_deployment(llm_config2.get_serve_options(name_prefix="VLLM:")).bind(llm_config2)
deployment1 = VLLMServer.as_deployment(llm_config1.get_serve_options(name_prefix="vLLM:")).bind(llm_config1)
deployment2 = VLLMServer.as_deployment(llm_config2.get_serve_options(name_prefix="vLLM:")).bind(llm_config2)
llm_app = LLMRouter.as_deployment().bind([deployment1, deployment2])
serve.run(llm_app)
@@ -203,7 +203,7 @@ For production deployments, Ray Serve LLM provides utilities for config-driven d
autoscaling_config:
min_replicas: 1
max_replicas: 2
import_path: ray.serve.llm.builders:build_openai_app
import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
@@ -219,7 +219,7 @@ For production deployments, Ray Serve LLM provides utilities for config-driven d
llm_configs:
- models/qwen-0.5b.yaml
- models/qwen-1.5b.yaml
import_path: ray.serve.llm.builders:build_openai_app
import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
@@ -274,8 +274,7 @@ This allows the weights to be loaded on each replica on-the-fly and be cached vi
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.builders import build_openai_app
from ray.serve.llm import LLMConfig, build_openai_app
# Configure the model with LoRA
llm_config = LLMConfig(
@@ -340,8 +339,7 @@ For structured output, you can use JSON mode similar to OpenAI's API:
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.builders import build_openai_app
from ray.serve.llm import LLMConfig, build_openai_app
# Configure the model with LoRA
llm_config = LLMConfig(
@@ -415,8 +413,7 @@ For multimodal models that can process both text and images:
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.builders import build_openai_app
from ray.serve.llm import LLMConfig, build_openai_app
# Configure a vision model
@@ -491,8 +488,7 @@ To set the deployment options, you can use the ``get_serve_options`` method on t
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.deployments import VLLMService, LLMRouter
from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
import os
llm_config = LLMConfig(
@@ -515,7 +511,7 @@
)
# Deploy the application
deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
deployment = VLLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
@@ -529,8 +525,7 @@ If you are using huggingface models, you can enable fast download by setting `HF
.. code-block:: python
from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.deployments import VLLMService, LLMRouter
from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
import os
llm_config = LLMConfig(
@@ -554,6 +549,6 @@ If you are using huggingface models, you can enable fast download by setting `HF
)
# Deploy the application
deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
deployment = VLLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
Loading

0 comments on commit e151dc2

Please sign in to comment.