From e151dc210348ab2ce37562583859f1afea035198 Mon Sep 17 00:00:00 2001
From: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
Date: Tue, 4 Mar 2025 09:50:16 -0800
Subject: [PATCH] [LLM APIs] Fast follow ups for 2.44 (1/N) (#51042)
Some of these changes came from the bug bash and dogfooding:
- [x] Rename `VLLMService` to `VLLMServer` (see the import sketch after this list)
- [x] Remove the extra namespace hierarchy from the most common import path
to make things more consistent with `data.llm`
- [x] Fix some other inconsistencies
- [x] Use vLLM everywhere (instead of VLLM). Most of these changes
happen on the serve side.
- [x] In `ray.data.llm`, use `model_source` in `vLLMEngineProcessorConfig`
(`model_id` is a serve-only concept from `ray.serve.llm`, as it refers to the
model name exposed to the model discovery layer)
- [ ] Support vLLM v1 [in follow-up]
- [ ] Allow passing a single deployment to the LLM router (instead of forcing
users to pass a list with one item) [in follow-up]
- [x] Update the Ray Serve docs structure to be flatter, based on
dogfooding feedback.
Serve LLM docs are now flatter and more consistent with the Serve docs:
Serving LLMs is a single page in the sidebar, and the APIs are a
sub-header under the Ray Serve APIs page (overview page and API page).
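
Minimal sketch of the new flat import surface, adapted from the docs updated in this PR; the model names, accelerator type, and autoscaling values are illustrative only:

```python
from ray import serve
# Previously: ray.serve.llm.configs / ray.serve.llm.deployments
from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(min_replicas=1, max_replicas=2),
    ),
    accelerator_type="A10G",
)

# VLLMServer (formerly VLLMService) behind the OpenAI-compatible router.
deployment = VLLMServer.as_deployment(
    llm_config.get_serve_options(name_prefix="vLLM:")
).bind(llm_config)
serve.run(LLMRouter.as_deployment().bind([deployment]))
```

On the data side, `vLLMEngineProcessorConfig` takes `model_source` instead of `model`:

```python
from ray.data.llm import vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",  # previously `model=...`
    concurrency=1,
    batch_size=64,
)
```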
---------
Signed-off-by: Kourosh Hakhamaneshi
---
doc/source/data/working-with-llms.rst | 12 +-
.../batch/vllm-with-structural-output.ipynb | 2 +-
doc/source/serve/api/index.md | 68 ++++
doc/source/serve/doc_code/vllm_example.py | 2 +-
doc/source/serve/index.md | 2 +-
doc/source/serve/llm/api.rst | 65 ----
doc/source/serve/llm/index.rst | 11 -
.../llm/{overview.rst => serving-llms.rst} | 45 ++-
python/ray/data/llm.py | 6 +-
.../ray/llm/_internal/batch/processor/base.py | 5 +-
.../batch/processor/vllm_engine_proc.py | 12 +-
python/ray/llm/_internal/common/__init__.py | 0
.../base.py => common/base_pydantic.py} | 0
.../serve/builders/application_builders.py | 4 +-
.../serve/configs/openai_api_models.py | 2 +-
.../_internal/serve/configs/server_models.py | 16 +-
.../deployments/llm/vllm/vllm_deployment.py | 8 +-
.../serve/deployments/llm/vllm/vllm_engine.py | 8 +-
.../serve/deployments/llm/vllm/vllm_models.py | 2 +-
.../observability/usage_telemetry/usage.py | 2 +-
.../gpu/processor/test_vllm_engine_proc.py | 8 +-
.../builders/test_application_builders.py | 2 +-
.../matching_configs/hf_prompt_format.yaml | 2 +-
.../test_lora_deployment_base_client.py | 2 +-
.../llm/multiplex/test_lora_model_loader.py | 2 +-
.../multiplex/test_multiplex_deployment.py | 2 +-
.../deployments/llm/vllm/test_vllm_engine.py | 2 +-
.../serve/deployments/mock_vllm_engine.py | 2 +-
.../ray/llm/tests/serve/mock_vllm_model.yaml | 2 +-
.../serve/mock_vllm_model_no_accelerator.yaml | 2 +-
.../usage_telemetry/test_usage.py | 8 +-
python/ray/serve/llm/__init__.py | 329 ++++++++++++++++++
python/ray/serve/llm/builders.py | 145 --------
python/ray/serve/llm/configs.py | 44 ---
python/ray/serve/llm/deployments.py | 122 -------
.../ray/serve/tests/unit/test_llm_imports.py | 8 +-
.../llm_tests/serve_llama_3dot1_8b_lora.yaml | 2 +-
.../serve_llama_3dot1_8b_quantized_tp1.yaml | 2 +-
.../llm_tests/serve_llama_3dot1_8b_tp2.yaml | 2 +-
39 files changed, 482 insertions(+), 478 deletions(-)
delete mode 100644 doc/source/serve/llm/api.rst
delete mode 100644 doc/source/serve/llm/index.rst
rename doc/source/serve/llm/{overview.rst => serving-llms.rst} (90%)
create mode 100644 python/ray/llm/_internal/common/__init__.py
rename python/ray/llm/_internal/{serve/configs/base.py => common/base_pydantic.py} (100%)
delete mode 100644 python/ray/serve/llm/builders.py
delete mode 100644 python/ray/serve/llm/configs.py
delete mode 100644 python/ray/serve/llm/deployments.py
diff --git a/doc/source/data/working-with-llms.rst b/doc/source/data/working-with-llms.rst
index 6c0e34172f7c..4596509c866f 100644
--- a/doc/source/data/working-with-llms.rst
+++ b/doc/source/data/working-with-llms.rst
@@ -40,7 +40,7 @@ Upon execution, the Processor object instantiates replicas of the vLLM engine (u
import numpy as np
config = vLLMEngineProcessorConfig(
- model="unsloth/Llama-3.1-8B-Instruct",
+ model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096,
@@ -82,7 +82,7 @@ Some models may require a Hugging Face token to be specified. You can specify th
.. testcode::
config = vLLMEngineProcessorConfig(
- model="unsloth/Llama-3.1-8B-Instruct",
+ model_source="unsloth/Llama-3.1-8B-Instruct",
runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}},
concurrency=1,
batch_size=64,
@@ -100,7 +100,7 @@ Use the `vLLMEngineProcessorConfig` to configure the vLLM engine.
from ray.data.llm import vLLMEngineProcessorConfig
config = vLLMEngineProcessorConfig(
- model="unsloth/Llama-3.1-8B-Instruct",
+ model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={"max_model_len": 20000},
concurrency=1,
batch_size=64,
@@ -111,7 +111,7 @@ For handling larger models, specify model parallelism.
.. testcode::
config = vLLMEngineProcessorConfig(
- model="unsloth/Llama-3.1-8B-Instruct",
+ model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"max_model_len": 16384,
"tensor_parallel_size": 2,
@@ -132,7 +132,7 @@ To optimize model loading, you can configure the `load_format` to `runai_streame
.. testcode::
config = vLLMEngineProcessorConfig(
- model="unsloth/Llama-3.1-8B-Instruct",
+ model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={"load_format": "runai_streamer"},
concurrency=1,
batch_size=64,
@@ -143,7 +143,7 @@ To do multi-LoRA batch inference, you need to set LoRA related parameters in `en
.. testcode::
config = vLLMEngineProcessorConfig(
- model="unsloth/Llama-3.1-8B-Instruct",
+ model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
enable_lora=True,
max_lora_rank=32,
diff --git a/doc/source/llm/examples/batch/vllm-with-structural-output.ipynb b/doc/source/llm/examples/batch/vllm-with-structural-output.ipynb
index 3e5366dd47dc..5fb8955390ec 100644
--- a/doc/source/llm/examples/batch/vllm-with-structural-output.ipynb
+++ b/doc/source/llm/examples/batch/vllm-with-structural-output.ipynb
@@ -42,7 +42,7 @@
"# 2. construct a vLLM processor config.\n",
"processor_config = vLLMEngineProcessorConfig(\n",
" # The base model.\n",
- " model=\"unsloth/Llama-3.2-1B-Instruct\",\n",
+ " model_source=\"unsloth/Llama-3.2-1B-Instruct\",\n",
" # vLLM engine config.\n",
" engine_kwargs=dict(\n",
" # Specify the guided decoding library to use. The default is \"xgrammar\".\n",
diff --git a/doc/source/serve/api/index.md b/doc/source/serve/api/index.md
index 0a7afdab22ba..232099497420 100644
--- a/doc/source/serve/api/index.md
+++ b/doc/source/serve/api/index.md
@@ -396,3 +396,71 @@ Content-Type: application/json
metrics.Gauge
schema.LoggingConfig
```
+
+(serve-llm-api)=
+
+## LLM API
+
+```{eval-rst}
+.. currentmodule:: ray
+```
+
+
+### Builders
+
+```{eval-rst}
+
+.. autosummary::
+ :nosignatures:
+ :toctree: doc/
+
+ serve.llm.build_vllm_deployment
+ serve.llm.build_openai_app
+```
+
+### Configs
+
+```{eval-rst}
+
+.. autosummary::
+ :nosignatures:
+ :toctree: doc/
+ :template: autosummary/autopydantic.rst
+
+ serve.llm.LLMConfig
+ serve.llm.LLMServingArgs
+ serve.llm.ModelLoadingConfig
+ serve.llm.CloudMirrorConfig
+ serve.llm.LoraConfig
+```
+
+
+### Deployments
+
+```{eval-rst}
+
+.. autosummary::
+ :nosignatures:
+ :toctree: doc/
+
+ serve.llm.VLLMServer
+ serve.llm.LLMRouter
+```
+
+### OpenAI API Models
+
+```{eval-rst}
+
+.. autosummary::
+ :nosignatures:
+ :toctree: doc/
+ :template: autosummary/autopydantic_show_json.rst
+
+ serve.llm.openai_api_models.ChatCompletionRequest
+ serve.llm.openai_api_models.CompletionRequest
+ serve.llm.openai_api_models.ChatCompletionStreamResponse
+ serve.llm.openai_api_models.ChatCompletionResponse
+ serve.llm.openai_api_models.CompletionStreamResponse
+ serve.llm.openai_api_models.CompletionResponse
+ serve.llm.openai_api_models.ErrorResponse
+```
diff --git a/doc/source/serve/doc_code/vllm_example.py b/doc/source/serve/doc_code/vllm_example.py
index bb88519acb02..9b94c3c59c37 100644
--- a/doc/source/serve/doc_code/vllm_example.py
+++ b/doc/source/serve/doc_code/vllm_example.py
@@ -16,7 +16,7 @@
class VLLMPredictDeployment:
def __init__(self, **kwargs):
"""
- Construct a VLLM deployment.
+ Construct a vLLM deployment.
Refer to https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
for the full list of arguments.
diff --git a/doc/source/serve/index.md b/doc/source/serve/index.md
index 5cc34ad54f61..7c581ab947d4 100644
--- a/doc/source/serve/index.md
+++ b/doc/source/serve/index.md
@@ -13,7 +13,7 @@ multi-app
model-multiplexing
configure-serve-deployment
http-guide
-Serving LLMs
+Serving LLMs
Production Guide
monitoring
resource-allocation
diff --git a/doc/source/serve/llm/api.rst b/doc/source/serve/llm/api.rst
deleted file mode 100644
index 248aefd34dcf..000000000000
--- a/doc/source/serve/llm/api.rst
+++ /dev/null
@@ -1,65 +0,0 @@
-Ray Serve LLM API
-==============================
-
-
-.. currentmodule:: ray.serve.llm.builders
-
-
-
-Builders
----------------------
-
-.. autosummary::
- :nosignatures:
- :toctree: doc/
-
- build_vllm_deployment
- build_openai_app
-
-
-.. currentmodule:: ray.serve.llm.configs
-
-Configs
----------------------
-.. autosummary::
- :nosignatures:
- :toctree: doc/
- :template: autosummary/autopydantic.rst
-
- LLMConfig
- LLMServingArgs
- ModelLoadingConfig
- CloudMirrorConfig
- LoraConfig
-
-.. currentmodule:: ray.serve.llm.deployments
-
-Deployments
----------------------
-
-.. autosummary::
- :nosignatures:
- :toctree: doc/
-
- VLLMService
- LLMRouter
-
-
-.. currentmodule:: ray.serve.llm.openai_api_models
-
-OpenAI API Models
----------------------
-
-.. autosummary::
- :nosignatures:
- :toctree: doc/
- :template: autosummary/autopydantic_show_json.rst
-
- ChatCompletionRequest
- CompletionRequest
- ChatCompletionStreamResponse
- ChatCompletionResponse
- CompletionStreamResponse
- CompletionResponse
- ErrorResponse
-
diff --git a/doc/source/serve/llm/index.rst b/doc/source/serve/llm/index.rst
deleted file mode 100644
index d73c71775c3a..000000000000
--- a/doc/source/serve/llm/index.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-LLM Serving
-===========
-
-Welcome to the LLM Serving documentation. Here you can find information about the API as well as user guides.
-
-.. toctree::
- :maxdepth: 1
- :titlesonly:
-
- overview
- api
diff --git a/doc/source/serve/llm/overview.rst b/doc/source/serve/llm/serving-llms.rst
similarity index 90%
rename from doc/source/serve/llm/overview.rst
rename to doc/source/serve/llm/serving-llms.rst
index 1da02b6f6523..a480af21871c 100644
--- a/doc/source/serve/llm/overview.rst
+++ b/doc/source/serve/llm/serving-llms.rst
@@ -1,5 +1,7 @@
-Overview
-========
+.. _serving_llms:
+
+Serving LLMs
+============
Ray Serve LLM APIs allow users to deploy multiple LLM models together with a familiar Ray Serve API, while providing compatibility with the OpenAI API.
@@ -27,10 +29,10 @@ Key Components
The ``ray.serve.llm`` module provides two key deployment types for serving LLMs:
-VLLMService
+VLLMServer
~~~~~~~~~~~~~~~~~~
-The VLLMService sets up and manages the vLLM engine for model serving. It can be used standalone or combined with your own custom Ray Serve deployments.
+The VLLMServer sets up and manages the vLLM engine for model serving. It can be used standalone or combined with your own custom Ray Serve deployments.
LLMRouter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -65,8 +67,7 @@ Deployment through ``LLMRouter``
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.deployments import VLLMService, LLMRouter
+ from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
llm_config = LLMConfig(
model_loading_config=dict(
@@ -87,7 +88,7 @@ Deployment through ``LLMRouter``
)
# Deploy the application
- deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
+ deployment = VLLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
@@ -135,8 +136,7 @@ For deploying multiple models, you can pass a list of ``LLMConfig`` objects to t
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.deployments import VLLMService, LLMRouter
+ from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
llm_config1 = LLMConfig(
model_loading_config=dict(
@@ -165,8 +165,8 @@ For deploying multiple models, you can pass a list of ``LLMConfig`` objects to t
)
# Deploy the application
- deployment1 = VLLMService.as_deployment(llm_config1.get_serve_options(name_prefix="VLLM:")).bind(llm_config1)
- deployment2 = VLLMService.as_deployment(llm_config2.get_serve_options(name_prefix="VLLM:")).bind(llm_config2)
+ deployment1 = VLLMServer.as_deployment(llm_config1.get_serve_options(name_prefix="vLLM:")).bind(llm_config1)
+ deployment2 = VLLMServer.as_deployment(llm_config2.get_serve_options(name_prefix="vLLM:")).bind(llm_config2)
llm_app = LLMRouter.as_deployment().bind([deployment1, deployment2])
serve.run(llm_app)
@@ -203,7 +203,7 @@ For production deployments, Ray Serve LLM provides utilities for config-driven d
autoscaling_config:
min_replicas: 1
max_replicas: 2
- import_path: ray.serve.llm.builders:build_openai_app
+ import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
@@ -219,7 +219,7 @@ For production deployments, Ray Serve LLM provides utilities for config-driven d
llm_configs:
- models/qwen-0.5b.yaml
- models/qwen-1.5b.yaml
- import_path: ray.serve.llm.builders:build_openai_app
+ import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
@@ -274,8 +274,7 @@ This allows the weights to be loaded on each replica on-the-fly and be cached vi
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.builders import build_openai_app
+ from ray.serve.llm import LLMConfig, build_openai_app
# Configure the model with LoRA
llm_config = LLMConfig(
@@ -340,8 +339,7 @@ For structured output, you can use JSON mode similar to OpenAI's API:
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.builders import build_openai_app
+ from ray.serve.llm import LLMConfig, build_openai_app
# Configure the model with LoRA
llm_config = LLMConfig(
@@ -415,8 +413,7 @@ For multimodal models that can process both text and images:
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.builders import build_openai_app
+ from ray.serve.llm import LLMConfig, build_openai_app
# Configure a vision model
@@ -491,8 +488,7 @@ To set the deployment options, you can use the ``get_serve_options`` method on t
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.deployments import VLLMService, LLMRouter
+ from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
import os
llm_config = LLMConfig(
@@ -515,7 +511,7 @@ To set the deployment options, you can use the ``get_serve_options`` method on t
)
# Deploy the application
- deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
+ deployment = VLLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
@@ -529,8 +525,7 @@ If you are using huggingface models, you can enable fast download by setting `HF
.. code-block:: python
from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.deployments import VLLMService, LLMRouter
+ from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
import os
llm_config = LLMConfig(
@@ -554,6 +549,6 @@ If you are using huggingface models, you can enable fast download by setting `HF
)
# Deploy the application
- deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
+ deployment = VLLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
diff --git a/python/ray/data/llm.py b/python/ray/data/llm.py
index 1a848617fd13..096c768ec4f4 100644
--- a/python/ray/data/llm.py
+++ b/python/ray/data/llm.py
@@ -81,7 +81,7 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
"""The configuration for the vLLM engine processor.
Args:
- model: The model to use for the vLLM engine.
+ model_source: The model source to use for the vLLM engine.
batch_size: The batch size to send to the vLLM engine. Large batch sizes are
likely to saturate the compute resources and could achieve higher throughput.
On the other hand, small batch sizes are more fault-tolerant and could
@@ -120,7 +120,7 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
config = vLLMEngineProcessorConfig(
- model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
engine_kwargs=dict(
enable_prefix_caching=True,
enable_chunked_prefill=True,
@@ -187,7 +187,7 @@ def build_llm_processor(
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
config = vLLMEngineProcessorConfig(
- model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
engine_kwargs=dict(
enable_prefix_caching=True,
enable_chunked_prefill=True,
diff --git a/python/ray/llm/_internal/batch/processor/base.py b/python/ray/llm/_internal/batch/processor/base.py
index 323e919b0b3b..a05981c8d185 100644
--- a/python/ray/llm/_internal/batch/processor/base.py
+++ b/python/ray/llm/_internal/batch/processor/base.py
@@ -1,7 +1,7 @@
from collections import OrderedDict
from typing import Optional, List, Type, Callable, Dict
-from pydantic import BaseModel, Field
+from pydantic import Field
from ray.data.block import UserDefinedFunction
from ray.data import Dataset
@@ -12,9 +12,10 @@
wrap_preprocess,
wrap_postprocess,
)
+from ray.llm._internal.common.base_pydantic import BaseModelExtended
-class ProcessorConfig(BaseModel):
+class ProcessorConfig(BaseModelExtended):
"""The processor configuration."""
batch_size: int = Field(
diff --git a/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py b/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py
index 30c86737bc48..4f8662aa5bef 100644
--- a/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py
+++ b/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py
@@ -25,8 +25,8 @@ class vLLMEngineProcessorConfig(ProcessorConfig):
"""The configuration for the vLLM engine processor."""
# vLLM stage configurations.
- model: str = Field(
- description="The model to use for the vLLM engine.",
+ model_source: str = Field(
+ description="The model source to use for the vLLM engine.",
)
engine_kwargs: Dict[str, Any] = Field(
default_factory=dict,
@@ -122,7 +122,7 @@ def build_vllm_engine_processor(
stages.append(
ChatTemplateStage(
fn_constructor_kwargs=dict(
- model=config.model,
+ model=config.model_source,
chat_template=config.chat_template,
),
map_batches_kwargs=dict(
@@ -137,7 +137,7 @@ def build_vllm_engine_processor(
stages.append(
TokenizeStage(
fn_constructor_kwargs=dict(
- model=config.model,
+ model=config.model_source,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
@@ -152,7 +152,7 @@ def build_vllm_engine_processor(
stages.append(
vLLMEngineStage(
fn_constructor_kwargs=dict(
- model=config.model,
+ model=config.model_source,
engine_kwargs=config.engine_kwargs,
task_type=config.task_type,
max_pending_requests=config.max_pending_requests,
@@ -177,7 +177,7 @@ def build_vllm_engine_processor(
stages.append(
DetokenizeStage(
fn_constructor_kwargs=dict(
- model=config.model,
+ model=config.model_source,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
diff --git a/python/ray/llm/_internal/common/__init__.py b/python/ray/llm/_internal/common/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/python/ray/llm/_internal/serve/configs/base.py b/python/ray/llm/_internal/common/base_pydantic.py
similarity index 100%
rename from python/ray/llm/_internal/serve/configs/base.py
rename to python/ray/llm/_internal/common/base_pydantic.py
diff --git a/python/ray/llm/_internal/serve/builders/application_builders.py b/python/ray/llm/_internal/serve/builders/application_builders.py
index 45d0d0946fc7..b55e22d7072d 100644
--- a/python/ray/llm/_internal/serve/builders/application_builders.py
+++ b/python/ray/llm/_internal/serve/builders/application_builders.py
@@ -25,7 +25,7 @@ def build_vllm_deployment(
deployment_kwargs = {}
deployment_options = llm_config.get_serve_options(
- name_prefix="VLLMDeployment:",
+ name_prefix="vLLMDeployment:",
)
return VLLMDeployment.options(**deployment_options).bind(
@@ -39,7 +39,7 @@ def _get_llm_deployments(
) -> List[DeploymentHandle]:
llm_deployments = []
for llm_config in llm_base_models:
- if llm_config.llm_engine == LLMEngine.VLLM:
+ if llm_config.llm_engine == LLMEngine.vLLM:
llm_deployments.append(build_vllm_deployment(llm_config, deployment_kwargs))
else:
# Note (genesu): This should never happen because we validate the engine
diff --git a/python/ray/llm/_internal/serve/configs/openai_api_models.py b/python/ray/llm/_internal/serve/configs/openai_api_models.py
index 2e0bb3fe6685..5ea34b5c9bee 100644
--- a/python/ray/llm/_internal/serve/configs/openai_api_models.py
+++ b/python/ray/llm/_internal/serve/configs/openai_api_models.py
@@ -240,7 +240,7 @@ class ChatCompletionRequest(BaseModel):
Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]
] = "none"
- # NOTE this will be ignored by VLLM -- the model determines the behavior
+ # NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
user: Optional[str] = None
diff --git a/python/ray/llm/_internal/serve/configs/server_models.py b/python/ray/llm/_internal/serve/configs/server_models.py
index 970fadc2d9d8..1efdd9c2147a 100644
--- a/python/ray/llm/_internal/serve/configs/server_models.py
+++ b/python/ray/llm/_internal/serve/configs/server_models.py
@@ -49,7 +49,7 @@
ErrorResponse,
ResponseFormatType,
)
-from ray.llm._internal.serve.configs.base import BaseModelExtended
+from ray.llm._internal.common.base_pydantic import BaseModelExtended
transformers = try_import("transformers")
@@ -124,7 +124,7 @@ class InputModality(str, Enum):
class LLMEngine(str, Enum):
"""Enum that represents an LLMEngine."""
- VLLM = "VLLM"
+ vLLM = "vLLM"
class JSONModeOptions(BaseModelExtended):
@@ -214,7 +214,7 @@ class LLMConfig(BaseModelExtended):
)
llm_engine: str = Field(
- default=LLMEngine.VLLM.value,
+ default=LLMEngine.vLLM.value,
description=f"The LLMEngine that should be used to run the model. Only the following values are supported: {str([t.value for t in LLMEngine])}",
)
@@ -347,7 +347,7 @@ def get_engine_config(self):
LLMConfig not only has engine config but also deployment config, etc.
"""
- if self.llm_engine == LLMEngine.VLLM:
+ if self.llm_engine == LLMEngine.vLLM:
from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import (
VLLMEngineConfig,
)
@@ -423,17 +423,15 @@ def get_serve_options(
:skipif: True
from ray import serve
- from ray.serve.llm.configs import LLMConfig, ModelLoadingConfig
- from ray.serve.llm.deployments import VLLMDeployment
-
+ from ray.serve.llm import LLMConfig, VLLMServer
llm_config = LLMConfig(
- model_loading_config=ModelLoadingConfig(model_id="test_model"),
+ model_loading_config=dict(model_id="test_model"),
accelerator_type="L4",
runtime_env={"env_vars": {"FOO": "bar"}},
)
serve_options = llm_config.get_serve_options(name_prefix="Test:")
- vllm_app = VLLMDeployment.options(**serve_options).bind(llm_config)
+ vllm_app = VLLMServer.options(**serve_options).bind(llm_config)
serve.run(vllm_app)
Keyword Args:
diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_deployment.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_deployment.py
index 15fd4e1df2bc..df35bab95037 100644
--- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_deployment.py
+++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_deployment.py
@@ -372,7 +372,7 @@ async def process_completions(
)
-class VLLMService(LLMDeployment):
+class VLLMServer(LLMDeployment):
_default_engine_cls = VLLMEngine
_default_image_retriever_cls = ImageRetriever
@@ -573,7 +573,7 @@ async def _disk_lora_model(self, lora_model_id: str) -> DiskMultiplexConfig:
def as_deployment(
cls, deployment_options: Dict[str, Any] = None
) -> serve.Deployment:
- """Convert the VLLMService to a Ray Serve deployment.
+ """Convert the VLLMServer to a Ray Serve deployment.
Args:
deployment_options: A dictionary of deployment options.
@@ -604,8 +604,8 @@ def as_deployment(
health_check_period_s=DEFAULT_HEALTH_CHECK_PERIOD_S,
health_check_timeout_s=DEFAULT_HEALTH_CHECK_TIMEOUT_S,
)
-class VLLMDeployment(VLLMService):
- # Note (genesu): We are separating the VLLMService and VLLMDeployment just
+class VLLMDeployment(VLLMServer):
+ # Note (genesu): We are separating the VLLMServer and VLLMDeployment just
# to give developers an ability to test the implementation outside the Ray Serve.
# But in practice we should always test the VLLMDeployment class as a Serve
# deployment to ensure all functionalities can be run remotely asynchronously.
diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py
index d30a3705d94d..61421b7c639d 100644
--- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py
+++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py
@@ -252,7 +252,7 @@ def __init__(
self,
llm_config: LLMConfig,
):
- """Create a VLLM Engine class
+ """Create a vLLM Engine class
Args:
llm_config: The llm configuration for this engine
@@ -284,7 +284,7 @@ async def initialize_node(llm_config: LLMConfig) -> InitializeNodeOutput:
return await initialize_node_util(llm_config)
async def start(self):
- """Start the VLLM engine.
+ """Start the vLLM engine.
If the engine is already running, do nothing.
"""
@@ -422,7 +422,7 @@ async def _generate(
) -> AsyncGenerator[LLMRawResponse, None]:
"""Generate an LLMRawResponse stream
- The VLLM generation request will be passed into VLLM, and the resulting output
+ The vLLM generation request will be passed into vLLM, and the resulting output
will be wrapped in an LLMRawResponse and yielded back to the user.
Error handling:
@@ -439,7 +439,7 @@ async def _generate(
f"Request {vllm_generation_request.request_id} started. "
f"Prompt: {vllm_generation_request.prompt}"
)
- # Construct a results generator from VLLM
+ # Construct a results generator from vLLM
results_generator: AsyncGenerator["RequestOutput", None] = self.engine.generate(
prompt=vllm.inputs.TextPrompt(
prompt=vllm_generation_request.prompt,
diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py
index 23a6a29bee0f..1ca704716727 100644
--- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py
+++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py
@@ -12,7 +12,7 @@
from ray.llm._internal.utils import try_import
from ray.llm._internal.serve.observability.logging import get_logger
-from ray.llm._internal.serve.configs.base import BaseModelExtended
+from ray.llm._internal.common.base_pydantic import BaseModelExtended
from ray.llm._internal.serve.configs.server_models import (
DiskMultiplexConfig,
CloudMirrorConfig,
diff --git a/python/ray/llm/_internal/serve/observability/usage_telemetry/usage.py b/python/ray/llm/_internal/serve/observability/usage_telemetry/usage.py
index 0267f0b4d8d9..2f722326a5cf 100644
--- a/python/ray/llm/_internal/serve/observability/usage_telemetry/usage.py
+++ b/python/ray/llm/_internal/serve/observability/usage_telemetry/usage.py
@@ -11,7 +11,7 @@
from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.serve.deployments.llm.multiplex.utils import get_lora_model_ids
-from ray.llm._internal.serve.configs.base import BaseModelExtended
+from ray.llm._internal.common.base_pydantic import BaseModelExtended
if TYPE_CHECKING:
from ray.llm._internal.serve.configs.server_models import LLMConfig
diff --git a/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py b/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py
index b97cda69c0ff..55ffaa3f786c 100644
--- a/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py
+++ b/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py
@@ -11,7 +11,7 @@
def test_vllm_engine_processor(gpu_type, model_opt_125m):
config = vLLMEngineProcessorConfig(
- model=model_opt_125m,
+ model_source=model_opt_125m,
engine_kwargs=dict(
max_model_len=8192,
),
@@ -86,7 +86,7 @@ def test_generation_model(gpu_type, model_opt_125m):
"""
processor_config = vLLMEngineProcessorConfig(
- model=model_opt_125m,
+ model_source=model_opt_125m,
engine_kwargs=dict(
enable_prefix_caching=False,
enable_chunked_prefill=True,
@@ -133,7 +133,7 @@ def test_generation_model(gpu_type, model_opt_125m):
def test_embedding_model(gpu_type, model_opt_125m):
processor_config = vLLMEngineProcessorConfig(
- model=model_opt_125m,
+ model_source=model_opt_125m,
task_type="embed",
engine_kwargs=dict(
enable_prefix_caching=False,
@@ -177,7 +177,7 @@ def test_embedding_model(gpu_type, model_opt_125m):
def test_vision_model(gpu_type, model_llava_354m):
processor_config = vLLMEngineProcessorConfig(
- model=model_llava_354m,
+ model_source=model_llava_354m,
task_type="generate",
engine_kwargs=dict(
# Skip CUDA graph capturing to reduce startup time.
diff --git a/python/ray/llm/tests/serve/builders/test_application_builders.py b/python/ray/llm/tests/serve/builders/test_application_builders.py
index 5a128078ff33..b0113dd29741 100644
--- a/python/ray/llm/tests/serve/builders/test_application_builders.py
+++ b/python/ray/llm/tests/serve/builders/test_application_builders.py
@@ -153,7 +153,7 @@ def test_build_vllm_deployment(
shutdown_ray_and_serve,
use_mock_vllm_engine,
):
- """Test `build_vllm_deployment` can build a VLLM deployment."""
+ """Test `build_vllm_deployment` can build a vLLM deployment."""
app = build_vllm_deployment(llm_config)
assert isinstance(app, serve.Application)
diff --git a/python/ray/llm/tests/serve/configs/configs/matching_configs/hf_prompt_format.yaml b/python/ray/llm/tests/serve/configs/configs/matching_configs/hf_prompt_format.yaml
index 686e35c8e8fc..4b7dcea8a26f 100644
--- a/python/ray/llm/tests/serve/configs/configs/matching_configs/hf_prompt_format.yaml
+++ b/python/ray/llm/tests/serve/configs/configs/matching_configs/hf_prompt_format.yaml
@@ -6,7 +6,7 @@ model_loading_config:
model_id: mistral-community/pixtral-12b
model_source: "/home/ray/tests/rayllm/backend/server/configs/cached_model_processors/mistral-community--pixtral-12b"
-llm_engine: VLLM
+llm_engine: vLLM
engine_kwargs:
enable_chunked_prefill: true
diff --git a/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_deployment_base_client.py b/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_deployment_base_client.py
index dc906b9bc553..6a9c6d4dc893 100644
--- a/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_deployment_base_client.py
+++ b/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_deployment_base_client.py
@@ -20,7 +20,7 @@
model_loading_config:
model_id: meta-llama/Llama-2-7b-hf
-llm_engine: VLLM
+llm_engine: vLLM
engine_kwargs:
trust_remote_code: True
diff --git a/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_model_loader.py b/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_model_loader.py
index b51f5e3c4c17..2b120e2d7f61 100644
--- a/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_model_loader.py
+++ b/python/ray/llm/tests/serve/deployments/llm/multiplex/test_lora_model_loader.py
@@ -30,7 +30,7 @@ def llm_config(self):
"""Common LLM config used across tests."""
return LLMConfig(
model_loading_config=ModelLoadingConfig(model_id="llm_model_id"),
- llm_engine=LLMEngine.VLLM,
+ llm_engine=LLMEngine.vLLM,
accelerator_type="L4",
lora_config=LoraConfig(
dynamic_lora_loading_path="s3://fake-bucket-uri-abcd"
diff --git a/python/ray/llm/tests/serve/deployments/llm/multiplex/test_multiplex_deployment.py b/python/ray/llm/tests/serve/deployments/llm/multiplex/test_multiplex_deployment.py
index 766f41562943..550ce7308001 100644
--- a/python/ray/llm/tests/serve/deployments/llm/multiplex/test_multiplex_deployment.py
+++ b/python/ray/llm/tests/serve/deployments/llm/multiplex/test_multiplex_deployment.py
@@ -20,7 +20,7 @@
model_loading_config:
model_id: meta-llama/Llama-2-7b-hf
-llm_engine: VLLM
+llm_engine: vLLM
engine_kwargs:
trust_remote_code: True
diff --git a/python/ray/llm/tests/serve/deployments/llm/vllm/test_vllm_engine.py b/python/ray/llm/tests/serve/deployments/llm/vllm/test_vllm_engine.py
index 6baaa2ce791c..1e92f11d201f 100644
--- a/python/ray/llm/tests/serve/deployments/llm/vllm/test_vllm_engine.py
+++ b/python/ray/llm/tests/serve/deployments/llm/vllm/test_vllm_engine.py
@@ -53,7 +53,7 @@ def get_fake_responses(*tokens: List[str]):
for token in tokens:
total += token
- # For some reason VLLM appears to return the full text on each iteration
+ # For some reason vLLM appears to return the full text on each iteration
# We should fix this in vllm
output.append(
SimpleNamespace(
diff --git a/python/ray/llm/tests/serve/deployments/mock_vllm_engine.py b/python/ray/llm/tests/serve/deployments/mock_vllm_engine.py
index 1f9429d05d56..9880fa4156d2 100644
--- a/python/ray/llm/tests/serve/deployments/mock_vllm_engine.py
+++ b/python/ray/llm/tests/serve/deployments/mock_vllm_engine.py
@@ -34,7 +34,7 @@
class MockVLLMEngine:
def __init__(self, llm_config: LLMConfig):
- """Create a VLLM Engine class
+ """Create a vLLM Engine class
Args:
llm_config: The llm configuration for this engine
diff --git a/python/ray/llm/tests/serve/mock_vllm_model.yaml b/python/ray/llm/tests/serve/mock_vllm_model.yaml
index fd5428923d6b..879fc9ec8f3d 100644
--- a/python/ray/llm/tests/serve/mock_vllm_model.yaml
+++ b/python/ray/llm/tests/serve/mock_vllm_model.yaml
@@ -1,7 +1,7 @@
model_loading_config:
model_id: VLLMFakeModel
-llm_engine: VLLM
+llm_engine: vLLM
engine_kwargs:
max_model_len: 4096
diff --git a/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml b/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml
index 07f334fd12b3..701fb4171a39 100644
--- a/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml
+++ b/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml
@@ -1,7 +1,7 @@
model_loading_config:
model_id: VLLMFakeModel
-llm_engine: VLLM
+llm_engine: vLLM
engine_kwargs:
max_model_len: 4096
diff --git a/python/ray/llm/tests/serve/observability/usage_telemetry/test_usage.py b/python/ray/llm/tests/serve/observability/usage_telemetry/test_usage.py
index c2bcabb002bc..108c3437162f 100644
--- a/python/ray/llm/tests/serve/observability/usage_telemetry/test_usage.py
+++ b/python/ray/llm/tests/serve/observability/usage_telemetry/test_usage.py
@@ -42,14 +42,14 @@ def record_tag_func(key, value):
model_loading_config=ModelLoadingConfig(
model_id="llm_model_id",
),
- llm_engine=LLMEngine.VLLM,
+ llm_engine=LLMEngine.vLLM,
accelerator_type="L4",
)
llm_config_autoscale_model = LLMConfig(
model_loading_config=ModelLoadingConfig(
model_id="llm_config_autoscale_model_id",
),
- llm_engine=LLMEngine.VLLM,
+ llm_engine=LLMEngine.vLLM,
accelerator_type="A10G",
deployment_config=dict(
autoscaling_config=dict(
@@ -63,14 +63,14 @@ def record_tag_func(key, value):
model_loading_config=ModelLoadingConfig(
model_id="llm_config_json_model_id",
),
- llm_engine=LLMEngine.VLLM,
+ llm_engine=LLMEngine.vLLM,
accelerator_type="A10G",
)
llm_config_lora_model = LLMConfig(
model_loading_config=ModelLoadingConfig(
model_id="llm_config_lora_model_id",
),
- llm_engine=LLMEngine.VLLM,
+ llm_engine=LLMEngine.vLLM,
accelerator_type="A10G",
lora_config=LoraConfig(dynamic_lora_loading_path=dynamic_lora_loading_path),
)
diff --git a/python/ray/serve/llm/__init__.py b/python/ray/serve/llm/__init__.py
index e69de29bb2d1..83b3060153f2 100644
--- a/python/ray/serve/llm/__init__.py
+++ b/python/ray/serve/llm/__init__.py
@@ -0,0 +1,329 @@
+from typing import TYPE_CHECKING
+
+from ray.util.annotations import PublicAPI
+
+
+from ray.llm._internal.serve.configs.server_models import (
+ LLMConfig as _LLMConfig,
+ LLMServingArgs as _LLMServingArgs,
+ ModelLoadingConfig as _ModelLoadingConfig,
+ CloudMirrorConfig as _CloudMirrorConfig,
+ LoraConfig as _LoraConfig,
+)
+from ray.llm._internal.serve.deployments.llm.vllm.vllm_deployment import (
+ VLLMServer as _VLLMServer,
+)
+from ray.llm._internal.serve.deployments.routers.router import (
+ LLMRouter as _LLMRouter,
+)
+
+
+if TYPE_CHECKING:
+ from ray.serve.deployment import Application
+
+
+##########
+# Models
+##########
+
+
+@PublicAPI(stability="alpha")
+class LLMConfig(_LLMConfig):
+ """The configuration for starting an LLM deployment."""
+
+ pass
+
+
+@PublicAPI(stability="alpha")
+class LLMServingArgs(_LLMServingArgs):
+ """The configuration for starting an LLM deployment application."""
+
+ pass
+
+
+@PublicAPI(stability="alpha")
+class ModelLoadingConfig(_ModelLoadingConfig):
+ """The configuration for loading an LLM model."""
+
+ pass
+
+
+@PublicAPI(stability="alpha")
+class CloudMirrorConfig(_CloudMirrorConfig):
+ """The configuration for mirroring an LLM model from cloud storage."""
+
+ pass
+
+
+@PublicAPI(stability="alpha")
+class LoraConfig(_LoraConfig):
+ """The configuration for loading an LLM model with LoRA."""
+
+ pass
+
+
+##########
+# Builders
+##########
+
+
+@PublicAPI(stability="alpha")
+def build_vllm_deployment(llm_config: "LLMConfig") -> "Application":
+ """Helper to build a single vllm deployment from the given llm config.
+
+ Examples:
+ .. testcode::
+ :skipif: True
+
+ from ray import serve
+ from ray.serve.llm import LLMConfig, build_vllm_deployment
+
+ # Configure the model
+ llm_config = LLMConfig(
+ model_loading_config=dict(
+ model_id="llama-3.1-8b",
+ model_source="meta-llama/Llama-3.1-8b-instruct",
+ ),
+ deployment_config=dict(
+ autoscaling_config=dict(
+ min_replicas=1,
+ max_replicas=2,
+ )
+ ),
+ accelerator_type="A10G",
+ )
+
+ # Build the deployment
+ vllm_app = build_vllm_deployment(llm_config)
+
+ # Deploy the application
+ model_handle = serve.run(vllm_app)
+
+ # Querying the model handle
+ import asyncio
+ model_handle = model_handle.options(stream=True)
+ async def query_model(model_handle):
+ from ray.serve.llm.openai_api_models import ChatCompletionRequest
+
+ request = ChatCompletionRequest(
+ model="qwen-0.5b",
+ messages=[
+ {
+ "role": "user",
+ "content": "Hello, world!"
+ }
+ ]
+ )
+
+ resp = model_handle.chat.remote(request)
+ async for message in resp:
+ print("message: ", message)
+
+ asyncio.run(query_model(model_handle))
+
+ Args:
+ llm_config: The llm config to build vllm deployment.
+
+ Returns:
+ The configured Ray Serve Application for vllm deployment.
+ """
+ from ray.llm._internal.serve.builders import build_vllm_deployment
+
+ return build_vllm_deployment(llm_config=llm_config)
+
+
+@PublicAPI(stability="alpha")
+def build_openai_app(llm_serving_args: "LLMServingArgs") -> "Application":
+ """Helper to build an OpenAI compatible app with the llm deployment setup from
+ the given llm serving args. This is the main entry point for users to create a
+ Serve application serving LLMs.
+
+
+ Examples:
+ .. testcode::
+ :skipif: True
+
+ from ray import serve
+ from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
+
+ llm_config1 = LLMConfig(
+ model_loading_config=dict(
+ model_id="qwen-0.5b",
+ model_source="Qwen/Qwen2.5-0.5B-Instruct",
+ ),
+ deployment_config=dict(
+ autoscaling_config=dict(
+ min_replicas=1, max_replicas=2,
+ )
+ ),
+ accelerator_type="A10G",
+ )
+
+ llm_config2 = LLMConfig(
+ model_loading_config=dict(
+ model_id="qwen-1.5b",
+ model_source="Qwen/Qwen2.5-1.5B-Instruct",
+ ),
+ deployment_config=dict(
+ autoscaling_config=dict(
+ min_replicas=1, max_replicas=2,
+ )
+ ),
+ accelerator_type="A10G",
+ )
+
+ # Deploy the application
+ deployment1 = VLLMServer.as_deployment().bind(llm_config1)
+ deployment2 = VLLMServer.as_deployment().bind(llm_config2)
+ llm_app = LLMRouter.as_deployment().bind([deployment1, deployment2])
+ serve.run(llm_app)
+
+
+ # Querying the model via openai client
+ from openai import OpenAI
+
+ # Initialize client
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
+
+ # Basic completion
+ response = client.chat.completions.create(
+ model="qwen-0.5b",
+ messages=[{"role": "user", "content": "Hello!"}]
+ )
+
+ Args:
+ llm_serving_args: The list of llm configs or the paths to the llm config to
+ build the app.
+
+ Returns:
+ The configured Ray Serve Application router.
+ """
+ from ray.llm._internal.serve.builders import build_openai_app
+
+ return build_openai_app(llm_serving_args=llm_serving_args)
+
+
+#############
+# Deployments
+#############
+
+
+@PublicAPI(stability="alpha")
+class VLLMServer(_VLLMServer):
+ """The implementation of the vLLM engine deployment.
+
+ To build a Deployment object you should use `build_vllm_deployment` function.
+ We also expose a lower level API for more control over the deployment class
+ through `as_deployment` method.
+
+ Examples:
+ .. testcode::
+ :skipif: True
+
+ from ray import serve
+ from ray.serve.llm import LLMConfig, VLLMServer
+
+ # Configure the model
+ llm_config = LLMConfig(
+ model_loading_config=dict(
+ served_model_name="llama-3.1-8b",
+ model_source="meta-llama/Llama-3.1-8b-instruct",
+ ),
+ deployment_config=dict(
+ autoscaling_config=dict(
+ min_replicas=1,
+ max_replicas=8,
+ )
+ ),
+ )
+
+ # Build the deployment directly
+ VLLMDeployment = VLLMServer.as_deployment(llm_config.get_serve_options())
+ vllm_app = VLLMDeployment.bind(llm_config)
+
+ model_handle = serve.run(vllm_app)
+
+ # Query the model via `chat` api
+ from ray.serve.llm.openai_api_models import ChatCompletionRequest
+ request = ChatCompletionRequest(
+ model="llama-3.1-8b",
+ messages=[
+ {
+ "role": "user",
+ "content": "Hello, world!"
+ }
+ ]
+ )
+ response = ray.get(model_handle.chat(request))
+ print(response)
+ """
+
+ pass
+
+
+@PublicAPI(stability="alpha")
+class LLMRouter(_LLMRouter):
+
+ """The implementation of the OpenAI compatiple model router.
+
+ This deployment creates the following endpoints:
+ - /v1/chat/completions: Chat interface (OpenAI-style)
+ - /v1/completions: Text completion
+ - /v1/models: List available models
+ - /v1/models/{model}: Model information
+
+
+ Examples:
+ .. testcode::
+ :skipif: True
+
+
+ from ray import serve
+ from ray.serve.llm import LLMConfig, VLLMServer, LLMRouter
+ from ray.serve.llm.openai_api_models import ChatCompletionRequest
+
+
+ llm_config1 = LLMConfig(
+ model_loading_config=dict(
+ served_model_name="llama-3.1-8b", # Name shown in /v1/models
+ model_source="meta-llama/Llama-3.1-8b-instruct",
+ ),
+ deployment_config=dict(
+ autoscaling_config=dict(
+ min_replicas=1, max_replicas=8,
+ )
+ ),
+ )
+ llm_config2 = LLMConfig(
+ model_loading_config=dict(
+ served_model_name="llama-3.2-3b", # Name shown in /v1/models
+ model_source="meta-llama/Llama-3.2-3b-instruct",
+ ),
+ deployment_config=dict(
+ autoscaling_config=dict(
+ min_replicas=1, max_replicas=8,
+ )
+ ),
+ )
+
+ # Deploy the application
+ vllm_deployment1 = VLLMServer.as_deployment(llm_config1.get_serve_options()).bind(llm_config1)
+ vllm_deployment2 = VLLMServer.as_deployment(llm_config2.get_serve_options()).bind(llm_config2)
+ llm_app = LLMRouter.as_deployment().bind([vllm_deployment1, vllm_deployment2])
+ serve.run(llm_app)
+ """
+
+ pass
+
+
+__all__ = [
+ "LLMConfig",
+ "LLMServingArgs",
+ "ModelLoadingConfig",
+ "CloudMirrorConfig",
+ "LoraConfig",
+ "build_vllm_deployment",
+ "build_openai_app",
+ "VLLMServer",
+ "LLMRouter",
+]
diff --git a/python/ray/serve/llm/builders.py b/python/ray/serve/llm/builders.py
deleted file mode 100644
index eec9f139bb8e..000000000000
--- a/python/ray/serve/llm/builders.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from typing import TYPE_CHECKING
-
-from ray.util.annotations import PublicAPI
-
-if TYPE_CHECKING:
- from ray.serve.deployment import Application
- from ray.serve.llm.configs import LLMConfig, LLMServingArgs
-
-
-@PublicAPI(stability="alpha")
-def build_vllm_deployment(llm_config: "LLMConfig") -> "Application":
- """Helper to build a single vllm deployment from the given llm config.
-
- Examples:
- .. testcode::
- :skipif: True
-
- from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.builders import build_vllm_deployment
-
- # Configure the model
- llm_config = LLMConfig(
- model_loading_config=dict(
- model_id="llama-3.1-8b",
- model_source="meta-llama/Llama-3.1-8b-instruct",
- ),
- deployment_config=dict(
- autoscaling_config=dict(
- min_replicas=1,
- max_replicas=2,
- )
- ),
- accelerator_type="A10G",
- )
-
- # Build the deployment
- vllm_app = build_vllm_deployment(llm_config)
-
- # Deploy the application
- model_handle = serve.run(vllm_app)
-
- # Querying the model handle
- import asyncio
- model_handle = model_handle.options(stream=True)
- async def query_model(model_handle):
- from ray.serve.llm.openai_api_models import ChatCompletionRequest
-
- request = ChatCompletionRequest(
- model="qwen-0.5b",
- messages=[
- {
- "role": "user",
- "content": "Hello, world!"
- }
- ]
- )
-
- resp = model_handle.chat.remote(request)
- async for message in resp:
- print("message: ", message)
-
- asyncio.run(query_model(model_handle))
-
- Args:
- llm_config: The llm config to build vllm deployment.
-
- Returns:
- The configured Ray Serve Application for vllm deployment.
- """
- from ray.llm._internal.serve.builders import build_vllm_deployment
-
- return build_vllm_deployment(llm_config=llm_config)
-
-
-@PublicAPI(stability="alpha")
-def build_openai_app(llm_serving_args: "LLMServingArgs") -> "Application":
- """Helper to build an OpenAI compatible app with the llm deployment setup from
- the given llm serving args. This is the main entry point for users to create a
- Serve application serving LLMs.
-
-
- Examples:
- .. testcode::
- :skipif: True
-
- from ray import serve
- from ray.serve.llm.configs import LLMConfig
- from ray.serve.llm.deployments import VLLMService, LLMRouter
-
- llm_config1 = LLMConfig(
- model_loading_config=dict(
- model_id="qwen-0.5b",
- model_source="Qwen/Qwen2.5-0.5B-Instruct",
- ),
- deployment_config=dict(
- autoscaling_config=dict(
- min_replicas=1, max_replicas=2,
- )
- ),
- accelerator_type="A10G",
- )
-
- llm_config2 = LLMConfig(
- model_loading_config=dict(
- model_id="qwen-1.5b",
- model_source="Qwen/Qwen2.5-1.5B-Instruct",
- ),
- deployment_config=dict(
- autoscaling_config=dict(
- min_replicas=1, max_replicas=2,
- )
- ),
- accelerator_type="A10G",
- )
-
- # Deploy the application
- deployment1 = VLLMService.as_deployment().bind(llm_config1)
- deployment2 = VLLMService.as_deployment().bind(llm_config2)
- llm_app = LLMRouter.as_deployment().bind([deployment1, deployment2])
- serve.run(llm_app)
-
-
- # Querying the model via openai client
- from openai import OpenAI
-
- # Initialize client
- client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
-
- # Basic completion
- response = client.chat.completions.create(
- model="qwen-0.5b",
- messages=[{"role": "user", "content": "Hello!"}]
- )
-
- Args:
- llm_serving_args: The list of llm configs or the paths to the llm config to
- build the app.
-
- Returns:
- The configured Ray Serve Application router.
- """
- from ray.llm._internal.serve.builders import build_openai_app
-
- return build_openai_app(llm_serving_args=llm_serving_args)
diff --git a/python/ray/serve/llm/configs.py b/python/ray/serve/llm/configs.py
deleted file mode 100644
index d0b99ae131e2..000000000000
--- a/python/ray/serve/llm/configs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from ray.llm._internal.serve.configs.server_models import (
- LLMConfig as _LLMConfig,
- LLMServingArgs as _LLMServingArgs,
- ModelLoadingConfig as _ModelLoadingConfig,
- CloudMirrorConfig as _CloudMirrorConfig,
- LoraConfig as _LoraConfig,
-)
-
-from ray.util.annotations import PublicAPI
-
-
-@PublicAPI(stability="alpha")
-class LLMConfig(_LLMConfig):
- """The configuration for starting an LLM deployment."""
-
- pass
-
-
-@PublicAPI(stability="alpha")
-class LLMServingArgs(_LLMServingArgs):
- """The configuration for starting an LLM deployment application."""
-
- pass
-
-
-@PublicAPI(stability="alpha")
-class ModelLoadingConfig(_ModelLoadingConfig):
- """The configuration for loading an LLM model."""
-
- pass
-
-
-@PublicAPI(stability="alpha")
-class CloudMirrorConfig(_CloudMirrorConfig):
- """The configuration for mirroring an LLM model from cloud storage."""
-
- pass
-
-
-@PublicAPI(stability="alpha")
-class LoraConfig(_LoraConfig):
- """The configuration for loading an LLM model with LoRA."""
-
- pass
diff --git a/python/ray/serve/llm/deployments.py b/python/ray/serve/llm/deployments.py
deleted file mode 100644
index 77d380f60d1e..000000000000
--- a/python/ray/serve/llm/deployments.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from ray.llm._internal.serve.deployments.llm.vllm.vllm_deployment import (
- VLLMService as _VLLMService,
-)
-from ray.llm._internal.serve.deployments.routers.router import (
- LLMRouter as _LLMRouter,
-)
-
-
-from ray.util.annotations import PublicAPI
-
-
-@PublicAPI(stability="alpha")
-class VLLMService(_VLLMService):
- """The implementation of the VLLM engine deployment.
-
- To build a VLLMDeployment object you should use `build_vllm_deployment` function.
- We also expose a lower level API for more control over the deployment class
- through `as_deployment` method.
-
- Examples:
- .. testcode::
- :skipif: True
-
- from ray import serve
- from ray.serve.config import AutoscalingConfig
- from ray.serve.llm.configs import LLMConfig, ModelLoadingConfig, DeploymentConfig
- from ray.serve.llm.deployments import VLLMDeployment
- from ray.serve.llm.openai_api_models import ChatCompletionRequest
-
- # Configure the model
- llm_config = LLMConfig(
- model_loading_config=ModelLoadingConfig(
- served_model_name="llama-3.1-8b",
- model_source="meta-llama/Llama-3.1-8b-instruct",
- ),
- deployment_config=DeploymentConfig(
- autoscaling_config=AutoscalingConfig(
- min_replicas=1,
- max_replicas=8,
- )
- ),
- )
-
- # Build the deployment directly
- VLLMDeployment = VLLMService.as_deployment(llm_config.get_serve_options())
- vllm_app = VLLMDeployment.bind(llm_config)
-
- model_handle = serve.run(vllm_app)
-
- # Query the model via `chat` api
- from ray.serve.llm.openai_api_models import ChatCompletionRequest
- request = ChatCompletionRequest(
- model="llama-3.1-8b",
- messages=[
- {
- "role": "user",
- "content": "Hello, world!"
- }
- ]
- )
- response = ray.get(model_handle.chat(request))
- print(response)
- """
-
- pass
-
-
-@PublicAPI(stability="alpha")
-class LLMRouter(_LLMRouter):
-
- """The implementation of the OpenAI compatiple model router.
-
- This deployment creates the following endpoints:
- - /v1/chat/completions: Chat interface (OpenAI-style)
- - /v1/completions: Text completion
- - /v1/models: List available models
- - /v1/models/{model}: Model information
-
-
- Examples:
- .. testcode::
- :skipif: True
-
-
- from ray import serve
- from ray.serve.config import AutoscalingConfig
- from ray.serve.llm.configs import LLMConfig, ModelLoadingConfig, DeploymentConfig
- from ray.serve.llm.deployments import VLLMDeployment
- from ray.serve.llm.openai_api_models import ChatCompletionRequest
-
-
- llm_config1 = LLMConfig(
- model_loading_config=ModelLoadingConfig(
- served_model_name="llama-3.1-8b", # Name shown in /v1/models
- model_source="meta-llama/Llama-3.1-8b-instruct",
- ),
- deployment_config=DeploymentConfig(
- autoscaling_config=AutoscalingConfig(
- min_replicas=1, max_replicas=8,
- )
- ),
- )
- llm_config2 = LLMConfig(
- model_loading_config=ModelLoadingConfig(
- served_model_name="llama-3.2-3b", # Name shown in /v1/models
- model_source="meta-llama/Llama-3.2-3b-instruct",
- ),
- deployment_config=DeploymentConfig(
- autoscaling_config=AutoscalingConfig(
- min_replicas=1, max_replicas=8,
- )
- ),
- )
-
- # Deploy the application
- vllm_deployment1 = VLLMDeployment.as_deployment(llm_config1.get_serve_options()).bind(llm_config1)
- vllm_deployment2 = VLLMDeployment.as_deployment(llm_config2.get_serve_options()).bind(llm_config2)
- llm_app = LLMModelRouterDeployment.as_deployment().bind([vllm_deployment1, vllm_deployment2])
- serve.run(llm_app)
- """
-
- pass
diff --git a/python/ray/serve/tests/unit/test_llm_imports.py b/python/ray/serve/tests/unit/test_llm_imports.py
index 5edc50668685..494d517e3d55 100644
--- a/python/ray/serve/tests/unit/test_llm_imports.py
+++ b/python/ray/serve/tests/unit/test_llm_imports.py
@@ -27,16 +27,16 @@ def test_serve_llm_import_does_not_error():
with pytest.raises(ImportError):
import ray.serve.llm # noqa: F401
with pytest.raises(ImportError):
- from ray.serve.llm.configs import (
+ from ray.serve.llm import (
LLMConfig, # noqa: F401
)
with pytest.raises(ImportError):
- from ray.serve.llm.deployments import (
- VLLMService, # noqa: F401
+ from ray.serve.llm import (
+ VLLMServer, # noqa: F401
LLMRouter, # noqa: F401
)
with pytest.raises(ImportError):
- from ray.serve.llm.builders import (
+ from ray.serve.llm import (
build_vllm_deployment, # noqa: F401
build_openai_app, # noqa: F401
)
diff --git a/release/llm_tests/serve_llama_3dot1_8b_lora.yaml b/release/llm_tests/serve_llama_3dot1_8b_lora.yaml
index 8d9bc8cdfac9..d9e34685e0c5 100644
--- a/release/llm_tests/serve_llama_3dot1_8b_lora.yaml
+++ b/release/llm_tests/serve_llama_3dot1_8b_lora.yaml
@@ -2,6 +2,6 @@ applications:
- args:
llm_configs:
- ./model_config/llama_3dot1_8b_lora.yaml
- import_path: ray.serve.llm.builders:build_openai_app
+ import_path: ray.serve.llm:build_openai_app
name: llm-endpoint
route_prefix: /
diff --git a/release/llm_tests/serve_llama_3dot1_8b_quantized_tp1.yaml b/release/llm_tests/serve_llama_3dot1_8b_quantized_tp1.yaml
index 226d5ef9a47b..85666c3b75e8 100644
--- a/release/llm_tests/serve_llama_3dot1_8b_quantized_tp1.yaml
+++ b/release/llm_tests/serve_llama_3dot1_8b_quantized_tp1.yaml
@@ -2,6 +2,6 @@ applications:
- args:
llm_configs:
- ./model_config/llama_3dot1_8b_quantized_tp1.yaml
- import_path: ray.serve.llm.builders:build_openai_app
+ import_path: ray.serve.llm:build_openai_app
name: llm-endpoint
route_prefix: /
diff --git a/release/llm_tests/serve_llama_3dot1_8b_tp2.yaml b/release/llm_tests/serve_llama_3dot1_8b_tp2.yaml
index 11e1c1ea003c..20654319487f 100644
--- a/release/llm_tests/serve_llama_3dot1_8b_tp2.yaml
+++ b/release/llm_tests/serve_llama_3dot1_8b_tp2.yaml
@@ -2,6 +2,6 @@ applications:
- args:
llm_configs:
- ./model_config/llama_3dot1_8b_tp2.yaml
- import_path: ray.serve.llm.builders:build_openai_app
+ import_path: ray.serve.llm:build_openai_app
name: llm-endpoint
route_prefix: /